Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix conflicts
Browse files Browse the repository at this point in the history
volovyks committed Jan 22, 2024
2 parents f89871e + 804e164 commit 82ba4fd
Showing 29 changed files with 2,136 additions and 971 deletions.
802 changes: 523 additions & 279 deletions Cargo.lock

Large diffs are not rendered by default.

160 changes: 68 additions & 92 deletions contract/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,39 @@
pub mod primitives;

use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize};
use near_sdk::collections::LookupMap;
use near_sdk::serde::{Deserialize, Serialize};
use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, Promise, PromiseOrValue, PublicKey};
use primitives::{CandidateInfo, Candidates, ParticipantInfo, Participants, PkVotes, Votes};
use std::collections::{BTreeMap, HashSet};

type ParticipantId = u32;

pub mod hpke {
pub type PublicKey = [u8; 32];
}

#[derive(
Serialize,
Deserialize,
BorshDeserialize,
BorshSerialize,
Clone,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Debug,
)]
pub struct ParticipantInfo {
pub id: ParticipantId,
pub account_id: AccountId,
pub url: String,
/// The public key used for encrypting messages.
pub cipher_pk: hpke::PublicKey,
/// The public key used for verifying messages.
pub sign_pk: PublicKey,
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
pub struct InitializingContractState {
pub participants: BTreeMap<AccountId, ParticipantInfo>,
pub participants: Participants,
pub threshold: usize,
pub pk_votes: BTreeMap<PublicKey, HashSet<ParticipantId>>,
pub pk_votes: PkVotes,
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
pub struct RunningContractState {
pub epoch: u64,
// TODO: why is this account id for participants instead of participant id?
pub participants: BTreeMap<AccountId, ParticipantInfo>,
pub participants: Participants,
pub threshold: usize,
pub public_key: PublicKey,
pub candidates: BTreeMap<ParticipantId, ParticipantInfo>,
pub join_votes: BTreeMap<ParticipantId, HashSet<ParticipantId>>,
pub leave_votes: BTreeMap<ParticipantId, HashSet<ParticipantId>>,
pub candidates: Candidates,
pub join_votes: Votes,
pub leave_votes: Votes,
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
pub struct ResharingContractState {
pub old_epoch: u64,
pub old_participants: BTreeMap<AccountId, ParticipantInfo>,
pub old_participants: Participants,
// TODO: only store diff to save on storage
pub new_participants: BTreeMap<AccountId, ParticipantInfo>,
pub new_participants: Participants,
pub threshold: usize,
pub public_key: PublicKey,
pub finished_votes: HashSet<ParticipantId>,
pub finished_votes: HashSet<AccountId>,
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
@@ -84,9 +57,9 @@ impl MpcContract {
pub fn init(threshold: usize, participants: BTreeMap<AccountId, ParticipantInfo>) -> Self {
MpcContract {
protocol_state: ProtocolContractState::Initializing(InitializingContractState {
participants,
participants: Participants { participants },
threshold,
pk_votes: BTreeMap::new(),
pk_votes: PkVotes::new(),
}),
pending_requests: LookupMap::new(b"m"),
}
@@ -98,9 +71,8 @@ impl MpcContract {

pub fn join(
&mut self,
participant_id: ParticipantId,
url: String,
cipher_pk: hpke::PublicKey,
cipher_pk: primitives::hpke::PublicKey,
sign_pk: PublicKey,
) {
match &mut self.protocol_state {
@@ -109,15 +81,14 @@ impl MpcContract {
candidates,
..
}) => {
let account_id = env::signer_account_id();
if participants.contains_key(&account_id) {
let signer_account_id = env::signer_account_id();
if participants.contains_key(&signer_account_id) {
env::panic_str("this participant is already in the participant set");
}
candidates.insert(
participant_id,
ParticipantInfo {
id: participant_id,
account_id,
signer_account_id.clone(),
CandidateInfo {
account_id: signer_account_id,
url,
cipher_pk,
sign_pk,
@@ -128,7 +99,7 @@ impl MpcContract {
}
}

pub fn vote_join(&mut self, participant: ParticipantId) -> bool {
pub fn vote_join(&mut self, candidate_account_id: AccountId) -> bool {
match &mut self.protocol_state {
ProtocolContractState::Running(RunningContractState {
epoch,
@@ -139,19 +110,19 @@ impl MpcContract {
join_votes,
..
}) => {
let voting_participant = participants
.get(&env::signer_account_id())
.unwrap_or_else(|| {
env::panic_str("calling account is not in the participant set")
});
let candidate = candidates
.get(&participant)
let signer_account_id = env::signer_account_id();
if !participants.contains_key(&signer_account_id) {
env::panic_str("calling account is not in the participant set");
}
let candidate_info = candidates
.get(&candidate_account_id)
.unwrap_or_else(|| env::panic_str("candidate is not registered"));
let voted = join_votes.entry(participant).or_default();
voted.insert(voting_participant.id);
let voted = join_votes.entry(candidate_account_id.clone());
voted.insert(signer_account_id);
if voted.len() >= *threshold {
let mut new_participants = participants.clone();
new_participants.insert(candidate.account_id.clone(), candidate.clone());
new_participants
.insert(candidate_account_id.clone(), candidate_info.clone().into());
self.protocol_state =
ProtocolContractState::Resharing(ResharingContractState {
old_epoch: *epoch,
@@ -170,30 +141,28 @@ impl MpcContract {
}
}

pub fn vote_leave(&mut self, participant: ParticipantId) -> bool {
pub fn vote_leave(&mut self, acc_id_to_leave: AccountId) -> bool {
match &mut self.protocol_state {
ProtocolContractState::Running(RunningContractState {
epoch,
participants,
threshold,
public_key,
candidates,
leave_votes,
..
}) => {
let voting_participant = participants
.get(&env::signer_account_id())
.unwrap_or_else(|| {
env::panic_str("calling account is not in the participant set")
});
let candidate = candidates
.get(&participant)
.unwrap_or_else(|| env::panic_str("candidate is not registered"));
let voted = leave_votes.entry(participant).or_default();
voted.insert(voting_participant.id);
let signer_account_id = env::signer_account_id();
if !participants.contains_key(&signer_account_id) {
env::panic_str("calling account is not in the participant set");
}
if !participants.contains_key(&acc_id_to_leave) {
env::panic_str("account to leave is not in the participant set");
}
let voted = leave_votes.entry(acc_id_to_leave.clone());
voted.insert(signer_account_id);
if voted.len() >= *threshold {
let mut new_participants = participants.clone();
new_participants.remove(&candidate.account_id);
new_participants.remove(&acc_id_to_leave);
self.protocol_state =
ProtocolContractState::Resharing(ResharingContractState {
old_epoch: *epoch,
@@ -219,22 +188,21 @@ impl MpcContract {
threshold,
pk_votes,
}) => {
let voting_participant = participants
.get(&env::signer_account_id())
.unwrap_or_else(|| {
env::panic_str("calling account is not in the participant set")
});
let voted = pk_votes.entry(public_key.clone()).or_default();
voted.insert(voting_participant.id);
let signer_account_id = env::signer_account_id();
if !participants.contains_key(&signer_account_id) {
env::panic_str("calling account is not in the participant set");
}
let voted = pk_votes.entry(public_key.clone());
voted.insert(signer_account_id);
if voted.len() >= *threshold {
self.protocol_state = ProtocolContractState::Running(RunningContractState {
epoch: 0,
participants: participants.clone(),
threshold: *threshold,
public_key,
candidates: BTreeMap::new(),
join_votes: BTreeMap::new(),
leave_votes: BTreeMap::new(),
candidates: Candidates::new(),
join_votes: Votes::new(),
leave_votes: Votes::new(),
});
true
} else {
@@ -260,21 +228,20 @@ impl MpcContract {
if *old_epoch + 1 != epoch {
env::panic_str("mismatched epochs");
}
let voting_participant = old_participants
.get(&env::signer_account_id())
.unwrap_or_else(|| {
env::panic_str("calling account is not in the old participant set")
});
finished_votes.insert(voting_participant.id);
let signer_account_id = env::signer_account_id();
if !old_participants.contains_key(&signer_account_id) {
env::panic_str("calling account is not in the old participant set");
}
finished_votes.insert(signer_account_id);
if finished_votes.len() >= *threshold {
self.protocol_state = ProtocolContractState::Running(RunningContractState {
epoch: *old_epoch + 1,
participants: new_participants.clone(),
threshold: *threshold,
public_key: public_key.clone(),
candidates: BTreeMap::new(),
join_votes: BTreeMap::new(),
leave_votes: BTreeMap::new(),
candidates: Candidates::new(),
join_votes: Votes::new(),
leave_votes: Votes::new(),
});
true
} else {
@@ -337,4 +304,13 @@ impl MpcContract {
pending_requests: LookupMap::new(b"m"),
}
}

/// This is the root public key combined from all the public keys of the participants.
pub fn public_key(&self) -> PublicKey {
match &self.protocol_state {
ProtocolContractState::Running(state) => state.public_key.clone(),
ProtocolContractState::Resharing(state) => state.public_key.clone(),
_ => env::panic_str("public key not available (protocol is not running or resharing)"),
}
}
}
199 changes: 199 additions & 0 deletions contract/src/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize};
use near_sdk::serde::{Deserialize, Serialize};
use near_sdk::{AccountId, PublicKey};
use std::collections::{BTreeMap, HashSet};

pub mod hpke {
pub type PublicKey = [u8; 32];
}

#[derive(
Serialize,
Deserialize,
BorshDeserialize,
BorshSerialize,
Clone,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Debug,
)]
pub struct ParticipantInfo {
pub account_id: AccountId,
pub url: String,
/// The public key used for encrypting messages.
pub cipher_pk: hpke::PublicKey,
/// The public key used for verifying messages.
pub sign_pk: PublicKey,
}

impl From<CandidateInfo> for ParticipantInfo {
fn from(candidate_info: CandidateInfo) -> Self {
ParticipantInfo {
account_id: candidate_info.account_id,
url: candidate_info.url,
cipher_pk: candidate_info.cipher_pk,
sign_pk: candidate_info.sign_pk,
}
}
}

#[derive(
Serialize,
Deserialize,
BorshDeserialize,
BorshSerialize,
Clone,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Debug,
)]
pub struct CandidateInfo {
pub account_id: AccountId,
pub url: String,
/// The public key used for encrypting messages.
pub cipher_pk: hpke::PublicKey,
/// The public key used for verifying messages.
pub sign_pk: PublicKey,
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)]
pub struct Participants {
pub participants: BTreeMap<AccountId, ParticipantInfo>,
}

impl Default for Participants {
fn default() -> Self {
Self::new()
}
}

impl Participants {
pub fn new() -> Self {
Participants {
participants: BTreeMap::new(),
}
}

pub fn contains_key(&self, account_id: &AccountId) -> bool {
self.participants.contains_key(account_id)
}

pub fn insert(&mut self, account_id: AccountId, participant_info: ParticipantInfo) {
self.participants.insert(account_id, participant_info);
}

pub fn remove(&mut self, account_id: &AccountId) {
self.participants.remove(account_id);
}

pub fn get(&self, account_id: &AccountId) -> Option<&ParticipantInfo> {
self.participants.get(account_id)
}

pub fn iter(&self) -> impl Iterator<Item = (&AccountId, &ParticipantInfo)> {
self.participants.iter()
}

pub fn keys(&self) -> impl Iterator<Item = &AccountId> {
self.participants.keys()
}

pub fn len(&self) -> usize {
self.participants.len()
}

pub fn is_empty(&self) -> bool {
self.participants.is_empty()
}
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)]
pub struct Candidates {
pub candidates: BTreeMap<AccountId, CandidateInfo>,
}

impl Default for Candidates {
fn default() -> Self {
Self::new()
}
}

impl Candidates {
pub fn new() -> Self {
Candidates {
candidates: BTreeMap::new(),
}
}

pub fn contains_key(&self, account_id: &AccountId) -> bool {
self.candidates.contains_key(account_id)
}

pub fn insert(&mut self, account_id: AccountId, candidate: CandidateInfo) {
self.candidates.insert(account_id, candidate);
}

pub fn remove(&mut self, account_id: &AccountId) {
self.candidates.remove(account_id);
}

pub fn get(&self, account_id: &AccountId) -> Option<&CandidateInfo> {
self.candidates.get(account_id)
}

pub fn iter(&self) -> impl Iterator<Item = (&AccountId, &CandidateInfo)> {
self.candidates.iter()
}
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
pub struct Votes {
pub votes: BTreeMap<AccountId, HashSet<AccountId>>,
}

impl Default for Votes {
fn default() -> Self {
Self::new()
}
}

impl Votes {
pub fn new() -> Self {
Votes {
votes: BTreeMap::new(),
}
}

pub fn entry(&mut self, account_id: AccountId) -> &mut HashSet<AccountId> {
self.votes.entry(account_id).or_default()
}
}

#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)]
pub struct PkVotes {
pub votes: BTreeMap<PublicKey, HashSet<AccountId>>,
}

impl Default for PkVotes {
fn default() -> Self {
Self::new()
}
}

impl PkVotes {
pub fn new() -> Self {
PkVotes {
votes: BTreeMap::new(),
}
}

pub fn entry(&mut self, public_key: PublicKey) -> &mut HashSet<AccountId> {
self.votes.entry(public_key).or_default()
}
}
2 changes: 1 addition & 1 deletion infra/mpc-recovery-testnet/main.tf
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ locals {

workspace = {
near_rpc = "https://rpc.testnet.near.org"
near_root_account = "near"
near_root_account = "testnet"
}
}

133 changes: 133 additions & 0 deletions infra/multichain-prod/main.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
terraform {
backend "gcs" {
bucket = "multichain-terraform-prod"
prefix = "state/multichain"
}

required_providers {
google = {
source = "hashicorp/google"
version = "4.73.0"
}
}
}

locals {
credentials = var.credentials != null ? var.credentials : file(var.credentials_file)
client_email = jsondecode(local.credentials).client_email
client_id = jsondecode(local.credentials).client_id

workspace = {
near_rpc = "https://rpc.mainnet.near.org"
}
}

data "external" "git_checkout" {
program = ["${path.module}/../scripts/get_sha.sh"]
}

provider "google" {
credentials = local.credentials

project = var.project
region = var.region
zone = var.zone
}

/*
* Create brand new service account with basic IAM
*/
resource "google_service_account" "service_account" {
account_id = "multichain-mainnet"
display_name = "Multichain mainnet Account"
}

resource "google_service_account_iam_binding" "serivce-account-iam" {
service_account_id = google_service_account.service_account.name
role = "roles/iam.serviceAccountUser"

members = [
"serviceAccount:${local.client_email}",
]
}

/*
* Ensure service account has access to Secret Manager variables
*/
resource "google_secret_manager_secret_iam_member" "account_sk_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].account_sk_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "cipher_sk_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].cipher_sk_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "aws_access_key_secret_access" {
secret_id = var.aws_access_key_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "aws_secret_key_secret_access" {
secret_id = var.aws_secret_key_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "sk_share_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].sk_share_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "sk_share_secret_manager" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].sk_share_secret_id
role = "roles/secretmanager.secretVersionManager"
member = "serviceAccount:${google_service_account.service_account.email}"
}

module "node" {
count = length(var.node_configs)
source = "../modules/multichain"

service_name = "multichain-mainnet-${count.index}"
project = var.project
region = var.region
service_account_email = google_service_account.service_account.email
docker_image = var.docker_image

node_id = count.index
near_rpc = local.workspace.near_rpc
mpc_contract_id = var.mpc_contract_id
account = var.node_configs[count.index].account
cipher_pk = var.node_configs[count.index].cipher_pk
indexer_options = var.indexer_options
my_address = var.node_configs[count.index].address

account_sk_secret_id = var.node_configs[count.index].account_sk_secret_id
cipher_sk_secret_id = var.node_configs[count.index].cipher_sk_secret_id
aws_access_key_secret_id = var.aws_access_key_secret_id
aws_secret_key_secret_id = var.aws_secret_key_secret_id
sk_share_secret_id = var.node_configs[count.index].sk_share_secret_id

depends_on = [
google_secret_manager_secret_iam_member.account_sk_secret_access,
google_secret_manager_secret_iam_member.cipher_sk_secret_access,
google_secret_manager_secret_iam_member.aws_access_key_secret_access,
google_secret_manager_secret_iam_member.aws_secret_key_secret_access,
google_secret_manager_secret_iam_member.sk_share_secret_access,
google_secret_manager_secret_iam_member.sk_share_secret_manager
]
}
57 changes: 57 additions & 0 deletions infra/multichain-prod/variables.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
variable "env" {
}

variable "project" {
}

variable "credentials_file" {
default = null
}

variable "credentials" {
default = null
}

variable "region" {
default = "us-east1"
}

variable "zone" {
default = "us-east1-c"
}

variable "docker_image" {
type = string
}

variable "mpc_contract_id" {
type = string
}

variable "indexer_options" {
type = object({
s3_bucket = string
s3_region = string
s3_url = string
start_block_height = number
})
}

variable "node_configs" {
type = list(object({
account = string
cipher_pk = string
address = string
account_sk_secret_id = string
cipher_sk_secret_id = string
sk_share_secret_id = string
}))
}

variable "aws_access_key_secret_id" {
type = string
}

variable "aws_secret_key_secret_id" {
type = string
}
133 changes: 133 additions & 0 deletions infra/multichain-testnet/main.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
terraform {
backend "gcs" {
bucket = "multichain-terraform-prod"
prefix = "state/multichain"
}

required_providers {
google = {
source = "hashicorp/google"
version = "4.73.0"
}
}
}

locals {
credentials = var.credentials != null ? var.credentials : file(var.credentials_file)
client_email = jsondecode(local.credentials).client_email
client_id = jsondecode(local.credentials).client_id

workspace = {
near_rpc = "https://rpc.testnet.near.org"
}
}

data "external" "git_checkout" {
program = ["${path.module}/../scripts/get_sha.sh"]
}

provider "google" {
credentials = local.credentials

project = var.project
region = var.region
zone = var.zone
}

/*
* Create brand new service account with basic IAM
*/
resource "google_service_account" "service_account" {
account_id = "multichain-testnet"
display_name = "Multichain testnet Account"
}

resource "google_service_account_iam_binding" "serivce-account-iam" {
service_account_id = google_service_account.service_account.name
role = "roles/iam.serviceAccountUser"

members = [
"serviceAccount:${local.client_email}",
]
}

/*
* Ensure service account has access to Secret Manager variables
*/
resource "google_secret_manager_secret_iam_member" "account_sk_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].account_sk_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "cipher_sk_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].cipher_sk_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "aws_access_key_secret_access" {
secret_id = var.aws_access_key_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "aws_secret_key_secret_access" {
secret_id = var.aws_secret_key_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "sk_share_secret_access" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].sk_share_secret_id
role = "roles/secretmanager.secretAccessor"
member = "serviceAccount:${google_service_account.service_account.email}"
}

resource "google_secret_manager_secret_iam_member" "sk_share_secret_manager" {
count = length(var.node_configs)

secret_id = var.node_configs[count.index].sk_share_secret_id
role = "roles/secretmanager.secretVersionManager"
member = "serviceAccount:${google_service_account.service_account.email}"
}

module "node" {
count = length(var.node_configs)
source = "../modules/multichain"

service_name = "multichain-testnet-${count.index}"
project = var.project
region = var.region
service_account_email = google_service_account.service_account.email
docker_image = var.docker_image

node_id = count.index
near_rpc = local.workspace.near_rpc
mpc_contract_id = var.mpc_contract_id
account = var.node_configs[count.index].account
cipher_pk = var.node_configs[count.index].cipher_pk
indexer_options = var.indexer_options
my_address = var.node_configs[count.index].address

account_sk_secret_id = var.node_configs[count.index].account_sk_secret_id
cipher_sk_secret_id = var.node_configs[count.index].cipher_sk_secret_id
aws_access_key_secret_id = var.aws_access_key_secret_id
aws_secret_key_secret_id = var.aws_secret_key_secret_id
sk_share_secret_id = var.node_configs[count.index].sk_share_secret_id

depends_on = [
google_secret_manager_secret_iam_member.account_sk_secret_access,
google_secret_manager_secret_iam_member.cipher_sk_secret_access,
google_secret_manager_secret_iam_member.aws_access_key_secret_access,
google_secret_manager_secret_iam_member.aws_secret_key_secret_access,
google_secret_manager_secret_iam_member.sk_share_secret_access,
google_secret_manager_secret_iam_member.sk_share_secret_manager
]
}
57 changes: 57 additions & 0 deletions infra/multichain-testnet/variables.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
variable "env" {
}

variable "project" {
}

variable "credentials_file" {
default = null
}

variable "credentials" {
default = null
}

variable "region" {
default = "us-east1"
}

variable "zone" {
default = "us-east1-c"
}

variable "docker_image" {
type = string
}

variable "mpc_contract_id" {
type = string
}

variable "indexer_options" {
type = object({
s3_bucket = string
s3_region = string
s3_url = string
start_block_height = number
})
}

variable "node_configs" {
type = list(object({
account = string
cipher_pk = string
address = string
account_sk_secret_id = string
cipher_sk_secret_id = string
sk_share_secret_id = string
}))
}

variable "aws_access_key_secret_id" {
type = string
}

variable "aws_secret_key_secret_id" {
type = string
}
14 changes: 8 additions & 6 deletions integration-tests/src/multichain/containers.rs
Original file line number Diff line number Diff line change
@@ -32,17 +32,15 @@ impl<'a> Node<'a> {

pub async fn run(
ctx: &super::Context<'a>,
node_id: u32,
account: &AccountId,
account_id: &AccountId,
account_sk: &near_workspaces::types::SecretKey,
) -> anyhow::Result<Node<'a>> {
tracing::info!(node_id, "running node container");
tracing::info!("running node container, account_id={}", account_id);
let (cipher_sk, cipher_pk) = hpke::generate();
let args = mpc_recovery_node::cli::Cli::Start {
node_id: node_id.into(),
near_rpc: ctx.lake_indexer.rpc_host_address.clone(),
mpc_contract_id: ctx.mpc_contract.id().clone(),
account: account.clone(),
account_id: account_id.clone(),
account_sk: account_sk.to_string().parse()?,
web_port: Self::CONTAINER_PORT,
cipher_pk: hex::encode(cipher_pk.to_bytes()),
@@ -80,7 +78,11 @@ impl<'a> Node<'a> {
});

let full_address = format!("http://{ip_address}:{}", Self::CONTAINER_PORT);
tracing::info!(node_id, full_address, "node container is running");
tracing::info!(
full_address,
"node container is running, account_id={}",
account_id
);
Ok(Node {
container,
address: full_address,
16 changes: 6 additions & 10 deletions integration-tests/src/multichain/local.rs
Original file line number Diff line number Diff line change
@@ -6,8 +6,7 @@ use near_workspaces::AccountId;
#[allow(dead_code)]
pub struct Node {
pub address: String,
node_id: usize,
account: AccountId,
account_id: AccountId,
pub account_sk: near_workspaces::types::SecretKey,
pub cipher_pk: hpke::PublicKey,
cipher_sk: hpke::SecretKey,
@@ -20,17 +19,15 @@ pub struct Node {
impl Node {
pub async fn run(
ctx: &super::Context<'_>,
node_id: u32,
account: &AccountId,
account_id: &AccountId,
account_sk: &near_workspaces::types::SecretKey,
) -> anyhow::Result<Self> {
let web_port = util::pick_unused_port().await?;
let (cipher_sk, cipher_pk) = hpke::generate();
let cli = mpc_recovery_node::cli::Cli::Start {
node_id: node_id.into(),
near_rpc: ctx.lake_indexer.rpc_host_address.clone(),
mpc_contract_id: ctx.mpc_contract.id().clone(),
account: account.clone(),
account_id: account_id.clone(),
account_sk: account_sk.to_string().parse()?,
web_port,
cipher_pk: hex::encode(cipher_pk.to_bytes()),
@@ -48,17 +45,16 @@ impl Node {
},
};

let mpc_node_id = format!("multichain/{node_id}");
let mpc_node_id = format!("multichain/{account_id}", account_id = account_id);
let process = mpc::spawn_multichain(ctx.release, &mpc_node_id, cli)?;
let address = format!("http://127.0.0.1:{web_port}");
tracing::info!("node is starting at {}", address);
util::ping_until_ok(&address, 60).await?;
tracing::info!("node started [node_id={node_id}, {address}]");
tracing::info!("node started [node_account_id={account_id}, {address}]");

Ok(Self {
address,
node_id: node_id as usize,
account: account.clone(),
account_id: account_id.clone(),
account_sk: account_sk.clone(),
cipher_pk,
cipher_sk,
28 changes: 9 additions & 19 deletions integration-tests/src/multichain/mod.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ pub mod local;

use crate::env::containers::DockerClient;
use crate::{initialize_lake_indexer, LakeIndexerCtx};
use mpc_contract::ParticipantInfo;
use mpc_contract::primitives::ParticipantInfo;
use near_workspaces::network::Sandbox;
use near_workspaces::{AccountId, Contract, Worker};
use serde_json::json;
@@ -50,17 +50,16 @@ impl Nodes<'_> {

pub async fn add_node(
&mut self,
node_id: u32,
account: &AccountId,
account_sk: &near_workspaces::types::SecretKey,
) -> anyhow::Result<()> {
tracing::info!(%account, "adding one more node");
match self {
Nodes::Local { ctx, nodes } => {
nodes.push(local::Node::run(ctx, node_id, account, account_sk).await?)
nodes.push(local::Node::run(ctx, account, account_sk).await?)
}
Nodes::Docker { ctx, nodes } => {
nodes.push(containers::Node::run(ctx, node_id, account, account_sk).await?)
nodes.push(containers::Node::run(ctx, account, account_sk).await?)
}
}

@@ -124,8 +123,8 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
let mut node_futures = Vec::new();
for (i, account) in accounts.iter().enumerate() {
let node = containers::Node::run(&ctx, i as u32, account.id(), account.secret_key());
for account in &accounts {
let node = containers::Node::run(&ctx, account.id(), account.secret_key());
node_futures.push(node);
}
let nodes = futures::future::join_all(node_futures)
@@ -135,13 +134,11 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul
let participants: HashMap<AccountId, ParticipantInfo> = accounts
.iter()
.cloned()
.enumerate()
.zip(&nodes)
.map(|((i, account), node)| {
.map(|(account, node)| {
(
account.id().clone(),
ParticipantInfo {
id: i as u32,
account_id: account.id().to_string().parse().unwrap(),
url: node.address.clone(),
cipher_pk: node.cipher_pk.to_bytes(),
@@ -171,13 +168,8 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result<
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
let mut node_futures = Vec::with_capacity(nodes);
for (i, account) in accounts.iter().enumerate().take(nodes) {
node_futures.push(local::Node::run(
&ctx,
i as u32,
account.id(),
account.secret_key(),
));
for account in accounts.iter().take(nodes) {
node_futures.push(local::Node::run(&ctx, account.id(), account.secret_key()));
}
let nodes = futures::future::join_all(node_futures)
.await
@@ -186,13 +178,11 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result<
let participants: HashMap<AccountId, ParticipantInfo> = accounts
.iter()
.cloned()
.enumerate()
.zip(&nodes)
.map(|((i, account), node)| {
.map(|(account, node)| {
(
account.id().clone(),
ParticipantInfo {
id: i as u32,
account_id: account.id().to_string().parse().unwrap(),
url: node.address.clone(),
cipher_pk: node.cipher_pk.to_bytes(),
2 changes: 1 addition & 1 deletion integration-tests/tests/multichain/mod.rs
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ async fn test_multichain_reshare() -> anyhow::Result<()> {

let account = ctx.nodes.ctx().worker.dev_create_account().await?;
ctx.nodes
.add_node(3, account.id(), account.secret_key())
.add_node(account.id(), account.secret_key())
.await?;

// Wait for network to complete key reshare
4 changes: 4 additions & 0 deletions node/Cargo.toml
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ clap = { version = "4.2", features = ["derive", "env"] }
google-secretmanager1 = "5"
hex = "0.4.3"
hkdf = "0.12.4"
highway = "1.1.0"
k256 = { version = "0.13.1", features = ["sha256", "ecdsa", "serde"] }
local-ip-address = "0.5.4"
rand = "0.8"
@@ -45,3 +46,6 @@ near-sdk = "5.0.0-alpha.1"

mpc-contract = { path = "../contract" }
mpc-keys = { path = "../keys" }

[dev-dependencies]
itertools = "0.12.0"
29 changes: 8 additions & 21 deletions node/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::protocol::{MpcSignProtocol, SignQueue};
use crate::{indexer, storage, web};
use cait_sith::protocol::Participant;
use clap::Parser;
use local_ip_address::local_ip;
use near_crypto::{InMemorySigner, SecretKey};
@@ -16,9 +15,6 @@ use mpc_keys::hpke;
#[derive(Parser, Debug)]
pub enum Cli {
Start {
/// Node ID
#[arg(long, value_parser = parse_participant, env("MPC_RECOVERY_NODE_ID"))]
node_id: Participant,
/// NEAR RPC address
#[arg(
long,
@@ -30,8 +26,8 @@ pub enum Cli {
#[arg(long, env("MPC_RECOVERY_CONTRACT_ID"))]
mpc_contract_id: AccountId,
/// This node's account id
#[arg(long, env("MPC_RECOVERY_ACCOUNT"))]
account: AccountId,
#[arg(long, env("MPC_RECOVERY_ACCOUNT_ID"))]
account_id: AccountId,
/// This node's account ed25519 secret key
#[arg(long, env("MPC_RECOVERY_ACCOUNT_SK"))]
account_sk: SecretKey,
@@ -57,19 +53,13 @@ pub enum Cli {
},
}

fn parse_participant(arg: &str) -> Result<Participant, std::num::ParseIntError> {
let participant_id: u32 = arg.parse()?;
Ok(participant_id.into())
}

impl Cli {
pub fn into_str_args(self) -> Vec<String> {
match self {
Cli::Start {
node_id,
near_rpc,
account_id,
mpc_contract_id,
account,
account_sk,
web_port,
cipher_pk,
@@ -80,14 +70,12 @@ impl Cli {
} => {
let mut args = vec![
"start".to_string(),
"--node-id".to_string(),
u32::from(node_id).to_string(),
"--near-rpc".to_string(),
near_rpc,
"--mpc-contract-id".to_string(),
mpc_contract_id.to_string(),
"--account".to_string(),
account.to_string(),
"--account-id".to_string(),
account_id.to_string(),
"--account-sk".to_string(),
account_sk.to_string(),
"--web-port".to_string(),
@@ -123,11 +111,10 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> {

match cmd {
Cli::Start {
node_id,
near_rpc,
web_port,
mpc_contract_id,
account,
account_id,
account_sk,
cipher_pk,
cipher_sk,
@@ -156,11 +143,11 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> {
tracing::info!(%my_address, "address detected");
let rpc_client = near_fetch::Client::new(&near_rpc);
tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized");
let signer = InMemorySigner::from_secret_key(account, account_sk);
let signer = InMemorySigner::from_secret_key(account_id.clone(), account_sk);
let (protocol, protocol_state) = MpcSignProtocol::init(
node_id,
my_address,
mpc_contract_id.clone(),
account_id,
rpc_client.clone(),
signer.clone(),
receiver,
13 changes: 9 additions & 4 deletions node/src/http_client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::protocol::contract::primitives::ParticipantInfo;
use crate::protocol::message::SignedMessage;
use crate::protocol::MpcMessage;
use crate::protocol::ParticipantInfo;
use cait_sith::protocol::Participant;
use mpc_keys::hpke;
use near_primitives::types::AccountId;
use reqwest::{Client, IntoUrl};
use std::collections::VecDeque;
use std::str::Utf8Error;
@@ -73,16 +74,20 @@ async fn send_encrypted<U: IntoUrl>(
Retry::spawn(retry_strategy, action).await
}

pub async fn join<U: IntoUrl>(client: &Client, url: U, me: &Participant) -> Result<(), SendError> {
let _span = tracing::info_span!("join_request", ?me);
pub async fn join<U: IntoUrl>(
client: &Client,
url: U,
account_id: &AccountId,
) -> Result<(), SendError> {
let _span = tracing::info_span!("join_request", ?account_id);
let mut url = url.into_url()?;
url.set_path("join");
tracing::debug!(%url, "making http request");
let action = || async {
let response = client
.post(url.clone())
.header("content-type", "application/json")
.json(&me)
.json(&account_id)
.send()
.await
.map_err(SendError::ReqwestClientError)?;
347 changes: 201 additions & 146 deletions node/src/protocol/consensus.rs

Large diffs are not rendered by default.

200 changes: 0 additions & 200 deletions node/src/protocol/contract.rs

This file was deleted.

131 changes: 131 additions & 0 deletions node/src/protocol/contract/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
pub mod primitives;

use crate::types::PublicKey;
use crate::util::NearPublicKeyExt;
use mpc_contract::ProtocolContractState;
use near_primitives::types::AccountId;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, str::FromStr};

use self::primitives::{Candidates, Participants, PkVotes, Votes};

#[derive(Serialize, Deserialize, Debug)]
pub struct InitializingContractState {
pub participants: Participants,
pub threshold: usize,
pub pk_votes: PkVotes,
}

impl From<mpc_contract::InitializingContractState> for InitializingContractState {
fn from(value: mpc_contract::InitializingContractState) -> Self {
InitializingContractState {
participants: value.participants.into(),
threshold: value.threshold,
pk_votes: value.pk_votes.into(),
}
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct RunningContractState {
pub epoch: u64,
pub participants: Participants,
pub threshold: usize,
pub public_key: PublicKey,
pub candidates: Candidates,
pub join_votes: Votes,
pub leave_votes: Votes,
}

impl From<mpc_contract::RunningContractState> for RunningContractState {
fn from(value: mpc_contract::RunningContractState) -> Self {
RunningContractState {
epoch: value.epoch,
participants: value.participants.into(),
threshold: value.threshold,
public_key: value.public_key.into_affine_point(),
candidates: value.candidates.into(),
join_votes: value.join_votes.into(),
leave_votes: value.leave_votes.into(),
}
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResharingContractState {
pub old_epoch: u64,
pub old_participants: Participants,
pub new_participants: Participants,
pub threshold: usize,
pub public_key: PublicKey,
pub finished_votes: HashSet<AccountId>,
}

impl From<mpc_contract::ResharingContractState> for ResharingContractState {
fn from(contract_state: mpc_contract::ResharingContractState) -> Self {
ResharingContractState {
old_epoch: contract_state.old_epoch,
old_participants: contract_state.old_participants.into(),
new_participants: contract_state.new_participants.into(),
threshold: contract_state.threshold,
public_key: contract_state.public_key.into_affine_point(),
finished_votes: contract_state
.finished_votes
.into_iter()
.map(|acc_id| AccountId::from_str(acc_id.as_ref()).unwrap())
.collect(),
}
}
}

#[derive(Debug)]
pub enum ProtocolState {
Initializing(InitializingContractState),
Running(RunningContractState),
Resharing(ResharingContractState),
}

impl ProtocolState {
pub fn participants(&self) -> &Participants {
match self {
ProtocolState::Initializing(InitializingContractState { participants, .. }) => {
participants
}
ProtocolState::Running(RunningContractState { participants, .. }) => participants,
ProtocolState::Resharing(ResharingContractState {
old_participants, ..
}) => old_participants,
}
}

pub fn public_key(&self) -> Option<&PublicKey> {
match self {
ProtocolState::Initializing { .. } => None,
ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key),
ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key),
}
}

pub fn threshold(&self) -> usize {
match self {
ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold,
ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold,
ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold,
}
}
}

impl TryFrom<ProtocolContractState> for ProtocolState {
type Error = ();

fn try_from(value: ProtocolContractState) -> Result<Self, Self::Error> {
match value {
ProtocolContractState::Initializing(state) => {
Ok(ProtocolState::Initializing(state.into()))
}
ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())),
ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())),
ProtocolContractState::NotInitialized => Err(()),
}
}
}
244 changes: 244 additions & 0 deletions node/src/protocol/contract/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use cait_sith::protocol::Participant;
use mpc_keys::hpke;
use near_primitives::{borsh::BorshDeserialize, types::AccountId};
use serde::{Deserialize, Serialize};
use std::{
collections::{BTreeMap, HashSet},
str::FromStr,
};

type ParticipantId = u32;

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ParticipantInfo {
pub id: ParticipantId,
pub account_id: AccountId,
pub url: String,
/// The public key used for encrypting messages.
pub cipher_pk: hpke::PublicKey,
/// The public key used for verifying messages.
pub sign_pk: near_crypto::PublicKey,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Participants {
pub participants: BTreeMap<Participant, ParticipantInfo>,
}

impl From<mpc_contract::primitives::Participants> for Participants {
fn from(contract_participants: mpc_contract::primitives::Participants) -> Self {
Participants {
// take position of participant in contract_participants as id for participants
participants: contract_participants
.participants
.into_iter()
.enumerate()
.map(|(participant_id, participant)| {
let contract_participant_info = participant.1;
(
Participant::from(participant_id as ParticipantId),
ParticipantInfo {
id: participant_id as ParticipantId,
account_id: AccountId::from_str(
contract_participant_info.account_id.as_ref(),
)
.unwrap(),
url: contract_participant_info.url,
cipher_pk: hpke::PublicKey::from_bytes(
&contract_participant_info.cipher_pk,
),
sign_pk: BorshDeserialize::try_from_slice(
contract_participant_info.sign_pk.as_bytes(),
)
.unwrap(),
},
)
})
.collect(),
}
}
}

impl IntoIterator for Participants {
type Item = (Participant, ParticipantInfo);
type IntoIter = std::collections::btree_map::IntoIter<Participant, ParticipantInfo>;

fn into_iter(self) -> Self::IntoIter {
self.participants.into_iter()
}
}

impl Participants {
pub fn get(&self, id: &Participant) -> Option<&ParticipantInfo> {
self.participants.get(id)
}

pub fn contains_key(&self, id: &Participant) -> bool {
self.participants.contains_key(id)
}

pub fn keys(&self) -> impl Iterator<Item = &Participant> {
self.participants.keys()
}

pub fn iter(&self) -> impl Iterator<Item = (&Participant, &ParticipantInfo)> {
self.participants.iter()
}

pub fn find_participant(&self, account_id: &AccountId) -> Option<Participant> {
self.participants
.iter()
.find(|(_, participant_info)| participant_info.account_id == *account_id)
.map(|(participant, _)| *participant)
}

pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> {
self.participants
.values()
.find(|participant_info| participant_info.account_id == *account_id)
}

pub fn contains_account_id(&self, account_id: &AccountId) -> bool {
self.participants
.values()
.any(|participant_info| participant_info.account_id == *account_id)
}

pub fn account_ids(&self) -> Vec<AccountId> {
self.participants
.values()
.map(|participant_info| participant_info.account_id.clone())
.collect()
}
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CandidateInfo {
pub account_id: AccountId,
pub url: String,
/// The public key used for encrypting messages.
pub cipher_pk: hpke::PublicKey,
/// The public key used for verifying messages.
pub sign_pk: near_crypto::PublicKey,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Candidates {
pub candidates: BTreeMap<AccountId, CandidateInfo>,
}

impl Candidates {
pub fn get(&self, id: &AccountId) -> Option<&CandidateInfo> {
self.candidates.get(id)
}

pub fn contains_key(&self, id: &AccountId) -> bool {
self.candidates.contains_key(id)
}

pub fn keys(&self) -> impl Iterator<Item = &AccountId> {
self.candidates.keys()
}

pub fn iter(&self) -> impl Iterator<Item = (&AccountId, &CandidateInfo)> {
self.candidates.iter()
}

pub fn find_candidate(&self, account_id: &AccountId) -> Option<&CandidateInfo> {
self.candidates.get(account_id)
}
}

impl From<mpc_contract::primitives::Candidates> for Candidates {
fn from(contract_candidates: mpc_contract::primitives::Candidates) -> Self {
Candidates {
candidates: contract_candidates
.candidates
.into_iter()
.map(|(account_id, candidate_info)| {
(
AccountId::from_str(account_id.as_ref()).unwrap(),
CandidateInfo {
account_id: AccountId::from_str(candidate_info.account_id.as_ref())
.unwrap(),
url: candidate_info.url,
cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk),
sign_pk: BorshDeserialize::try_from_slice(
candidate_info.sign_pk.as_bytes(),
)
.unwrap(),
},
)
})
.collect(),
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PkVotes {
pub pk_votes: BTreeMap<near_crypto::PublicKey, HashSet<AccountId>>,
}

impl PkVotes {
pub fn get(&self, id: &near_crypto::PublicKey) -> Option<&HashSet<AccountId>> {
self.pk_votes.get(id)
}
}

impl From<mpc_contract::primitives::PkVotes> for PkVotes {
fn from(contract_votes: mpc_contract::primitives::PkVotes) -> Self {
PkVotes {
pk_votes: contract_votes
.votes
.into_iter()
.map(|(pk, participants)| {
(
near_crypto::PublicKey::SECP256K1(
near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(),
),
participants
.into_iter()
.map(|acc_id: near_sdk::AccountId| {
AccountId::from_str(acc_id.as_ref()).unwrap()
})
.collect(),
)
})
.collect(),
}
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Votes {
pub votes: BTreeMap<AccountId, HashSet<AccountId>>,
}

impl Votes {
pub fn get(&self, id: &AccountId) -> Option<&HashSet<AccountId>> {
self.votes.get(id)
}
}

impl From<mpc_contract::primitives::Votes> for Votes {
fn from(contract_votes: mpc_contract::primitives::Votes) -> Self {
Votes {
votes: contract_votes
.votes
.into_iter()
.map(|(account_id, participants)| {
(
AccountId::from_str(account_id.as_ref()).unwrap(),
participants
.into_iter()
.map(|acc_id: near_sdk::AccountId| {
AccountId::from_str(acc_id.as_ref()).unwrap()
})
.collect(),
)
})
.collect(),
}
}
}
65 changes: 34 additions & 31 deletions node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
@@ -13,8 +13,9 @@ use mpc_keys::hpke;
use near_crypto::InMemorySigner;
use near_primitives::types::AccountId;

#[async_trait::async_trait]
pub trait CryptographicCtx {
fn me(&self) -> Participant;
async fn me(&self) -> Participant;
fn http_client(&self) -> &reqwest::Client;
fn rpc_client(&self) -> &near_fetch::Client;
fn signer(&self) -> &InMemorySigner;
@@ -69,58 +70,58 @@ impl CryptographicProtocol for GeneratingState {
mut self,
mut ctx: C,
) -> Result<NodeState, CryptographicError> {
tracing::info!("progressing key generation");
tracing::info!("generating: progressing key generation");
let mut protocol = self.protocol.write().await;
loop {
let action = protocol.poke()?;
match action {
Action::Wait => {
drop(protocol);
tracing::debug!("waiting");
tracing::debug!("generating: waiting");
if let Err(err) = self
.messages
.write()
.await
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message");
tracing::warn!(?err, participants = ?self.participants, "generating(wait): failed to send encrypted message");
}

return Ok(NodeState::Generating(self));
}
Action::SendMany(m) => {
tracing::debug!("sending a message to many participants");
tracing::debug!("generating: sending a message to many participants");
let mut messages = self.messages.write().await;
for (p, info) in &self.participants {
if p == &ctx.me() {
for (p, info) in self.participants.iter() {
if p == &ctx.me().await {
// Skip yourself, cait-sith never sends messages to oneself
continue;
}
messages.push(
info.clone(),
MpcMessage::Generating(GeneratingMessage {
from: ctx.me(),
from: ctx.me().await,
data: m.clone(),
}),
);
}
}
Action::SendPrivate(to, m) => {
tracing::debug!("sending a private message to {to:?}");
tracing::debug!("generating: sending a private message to {to:?}");
let info = self.fetch_participant(&to)?;
self.messages.write().await.push(
info.clone(),
MpcMessage::Generating(GeneratingMessage {
from: ctx.me(),
from: ctx.me().await,
data: m.clone(),
}),
);
}
Action::Return(r) => {
tracing::info!(
public_key = hex::encode(r.public_key.to_bytes()),
"successfully completed key generation"
"generating: successfully completed key generation"
);
ctx.secret_storage()
.store(&PersistentNodeData {
@@ -134,10 +135,10 @@ impl CryptographicProtocol for GeneratingState {
.messages
.write()
.await
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message");
tracing::warn!(?err, participants = ?self.participants, "generating(return): failed to send encrypted message");
}
return Ok(NodeState::WaitingForConsensus(WaitingForConsensusState {
epoch: 0,
@@ -163,10 +164,10 @@ impl CryptographicProtocol for WaitingForConsensusState {
.messages
.write()
.await
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, participants = ?self.participants, "waiting: failed to send encrypted message");
tracing::warn!(?err, participants = ?self.participants, "waitingForConsensus: failed to send encrypted message");
}

// Wait for ConsensusProtocol step to advance state
@@ -187,12 +188,12 @@ impl CryptographicProtocol for ResharingState {
match action {
Action::Wait => {
drop(protocol);
tracing::debug!("waiting");
tracing::debug!("resharing: waiting");
if let Err(err) = self
.messages
.write()
.await
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(wait): failed to send encrypted message");
@@ -201,10 +202,10 @@ impl CryptographicProtocol for ResharingState {
return Ok(NodeState::Resharing(self));
}
Action::SendMany(m) => {
tracing::debug!("sending a message to all participants");
tracing::debug!("resharing: sending a message to all participants");
let mut messages = self.messages.write().await;
for (p, info) in &self.new_participants {
if p == &ctx.me() {
for (p, info) in self.new_participants.clone() {
if p == ctx.me().await {
// Skip yourself, cait-sith never sends messages to oneself
continue;
}
@@ -213,35 +214,35 @@ impl CryptographicProtocol for ResharingState {
info.clone(),
MpcMessage::Resharing(ResharingMessage {
epoch: self.old_epoch,
from: ctx.me(),
from: ctx.me().await,
data: m.clone(),
}),
)
}
}
Action::SendPrivate(to, m) => {
tracing::debug!("sending a private message to {to:?}");
tracing::debug!("resharing: sending a private message to {to:?}");
match self.new_participants.get(&to) {
Some(info) => self.messages.write().await.push(
info.clone(),
MpcMessage::Resharing(ResharingMessage {
epoch: self.old_epoch,
from: ctx.me(),
from: ctx.me().await,
data: m.clone(),
}),
),
None => return Err(CryptographicError::UnknownParticipant(to)),
}
}
Action::Return(private_share) => {
tracing::debug!("successfully completed key reshare");
tracing::debug!("resharing: successfully completed key reshare");

// Send any leftover messages.
if let Err(err) = self
.messages
.write()
.await
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(return): failed to send encrypted message");
@@ -270,7 +271,7 @@ impl CryptographicProtocol for RunningState {
let mut messages = self.messages.write().await;
// Try sending any leftover messages donated to RunningState.
if let Err(err) = messages
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, participants = ?self.participants, "running(pre): failed to send encrypted message");
@@ -298,7 +299,9 @@ impl CryptographicProtocol for RunningState {
&self.private_share,
)?;
} else {
tracing::debug!("we don't have enough triples to generate a presignature");
tracing::debug!(
"running(pre): we don't have enough triples to generate a presignature"
);
}
}
drop(triple_manager);
@@ -309,8 +312,8 @@ impl CryptographicProtocol for RunningState {

let mut sign_queue = self.sign_queue.write().await;
let mut signature_manager = self.signature_manager.write().await;
sign_queue.organize(&self, ctx.me());
let my_requests = sign_queue.my_requests(ctx.me());
sign_queue.organize(&self, ctx.me().await);
let my_requests = sign_queue.my_requests(ctx.me().await);
while presignature_manager.my_len() > 0 {
let Some((receipt_id, _)) = my_requests.iter().next() else {
break;
@@ -340,7 +343,7 @@ impl CryptographicProtocol for RunningState {
.await?;
drop(signature_manager);
if let Err(err) = messages
.send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client())
.send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client())
.await
{
tracing::warn!(?err, participants = ?self.participants, "running(post): failed to send encrypted message");
21 changes: 4 additions & 17 deletions node/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -15,8 +15,9 @@ use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;

#[async_trait::async_trait]
pub trait MessageCtx {
fn me(&self) -> Participant;
async fn me(&self) -> Participant;
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
@@ -193,7 +194,6 @@ impl MessageHandler for ResharingState {
let q = queue.resharing_bins.entry(self.old_epoch).or_default();
let mut protocol = self.protocol.write().await;
while let Some(msg) = q.pop_front() {
tracing::debug!("handling new resharing message");
protocol.message(msg.from, msg.data);
}
Ok(())
@@ -210,9 +210,6 @@ impl MessageHandler for RunningState {
let mut triple_manager = self.triple_manager.write().await;
for (id, queue) in queue.triple_bins.entry(self.epoch).or_default() {
if let Some(protocol) = triple_manager.get_or_generate(*id)? {
let mut protocol = protocol
.write()
.map_err(|err| MessageHandleError::SyncError(err.to_string()))?;
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
}
@@ -231,12 +228,7 @@ impl MessageHandler for RunningState {
&self.public_key,
&self.private_share,
) {
Ok(protocol) => {
let mut protocol = protocol
.write()
.map_err(|err| MessageHandleError::SyncError(err.to_string()))?;
protocol.message(message.from, message.data)
}
Ok(protocol) => protocol.message(message.from, message.data),
Err(presignature::GenerationError::AlreadyGenerated) => {
tracing::info!(id, "presignature already generated, nothing left to do")
}
@@ -285,12 +277,7 @@ impl MessageHandler for RunningState {
message.delta,
&mut presignature_manager,
)? {
Some(protocol) => {
let mut protocol = protocol
.write()
.map_err(|err| MessageHandleError::SyncError(err.to_string()))?;
protocol.message(message.from, message.data)
}
Some(protocol) => protocol.message(message.from, message.data),
None => {
// Store the message until we are ready to process it
leftover_messages.push(message)
85 changes: 50 additions & 35 deletions node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod contract;
pub mod contract;
mod cryptography;
mod presignature;
mod signature;
@@ -9,7 +9,8 @@ pub mod message;
pub mod state;

pub use consensus::ConsensusError;
pub use contract::{ParticipantInfo, ProtocolState};
pub use contract::primitives::ParticipantInfo;
pub use contract::ProtocolState;
pub use cryptography::CryptographicError;
pub use message::MpcMessage;
pub use signature::SignQueue;
@@ -36,8 +37,8 @@ use url::Url;
use mpc_keys::hpke;

struct Ctx {
me: Participant,
my_address: Url,
account_id: AccountId,
mpc_contract_id: AccountId,
signer: InMemorySigner,
rpc_client: near_fetch::Client,
@@ -48,89 +49,91 @@ struct Ctx {
secret_storage: SecretNodeStorageBox,
}

impl ConsensusCtx for &Ctx {
fn me(&self) -> Participant {
self.me
impl ConsensusCtx for &MpcSignProtocol {
fn my_account_id(&self) -> &AccountId {
&self.ctx.account_id
}

fn http_client(&self) -> &reqwest::Client {
&self.http_client
&self.ctx.http_client
}

fn rpc_client(&self) -> &near_fetch::Client {
&self.rpc_client
&self.ctx.rpc_client
}

fn signer(&self) -> &InMemorySigner {
&self.signer
&self.ctx.signer
}

fn mpc_contract_id(&self) -> &AccountId {
&self.mpc_contract_id
&self.ctx.mpc_contract_id
}

fn my_address(&self) -> &Url {
&self.my_address
&self.ctx.my_address
}

fn sign_queue(&self) -> Arc<RwLock<SignQueue>> {
self.sign_queue.clone()
self.ctx.sign_queue.clone()
}

fn cipher_pk(&self) -> &hpke::PublicKey {
&self.cipher_pk
&self.ctx.cipher_pk
}

fn sign_pk(&self) -> near_crypto::PublicKey {
self.sign_sk.public_key()
self.ctx.sign_sk.public_key()
}

fn sign_sk(&self) -> &near_crypto::SecretKey {
&self.sign_sk
&self.ctx.sign_sk
}

fn secret_storage(&self) -> &SecretNodeStorageBox {
&self.secret_storage
&self.ctx.secret_storage
}
}

impl CryptographicCtx for &mut Ctx {
fn me(&self) -> Participant {
self.me
#[async_trait::async_trait]
impl CryptographicCtx for &mut MpcSignProtocol {
async fn me(&self) -> Participant {
get_my_participant(self).await
}

fn http_client(&self) -> &reqwest::Client {
&self.http_client
&self.ctx.http_client
}

fn rpc_client(&self) -> &near_fetch::Client {
&self.rpc_client
&self.ctx.rpc_client
}

fn signer(&self) -> &InMemorySigner {
&self.signer
&self.ctx.signer
}

fn mpc_contract_id(&self) -> &AccountId {
&self.mpc_contract_id
&self.ctx.mpc_contract_id
}

fn cipher_pk(&self) -> &hpke::PublicKey {
&self.cipher_pk
&self.ctx.cipher_pk
}

fn sign_sk(&self) -> &near_crypto::SecretKey {
&self.sign_sk
&self.ctx.sign_sk
}

fn secret_storage(&mut self) -> &mut SecretNodeStorageBox {
&mut self.secret_storage
&mut self.ctx.secret_storage
}
}

impl MessageCtx for &Ctx {
fn me(&self) -> Participant {
self.me
#[async_trait::async_trait]
impl MessageCtx for &MpcSignProtocol {
async fn me(&self) -> Participant {
get_my_participant(self).await
}
}

@@ -143,9 +146,9 @@ pub struct MpcSignProtocol {
impl MpcSignProtocol {
#![allow(clippy::too_many_arguments)]
pub fn init<U: IntoUrl>(
me: Participant,
my_address: U,
mpc_contract_id: AccountId,
account_id: AccountId,
rpc_client: near_fetch::Client,
signer: InMemorySigner,
receiver: mpsc::Receiver<MpcMessage>,
@@ -155,8 +158,8 @@ impl MpcSignProtocol {
) -> (Self, Arc<RwLock<NodeState>>) {
let state = Arc::new(RwLock::new(NodeState::Starting));
let ctx = Ctx {
me,
my_address: my_address.into_url().unwrap(),
account_id,
mpc_contract_id,
rpc_client,
http_client: reqwest::Client::new(),
@@ -175,7 +178,7 @@ impl MpcSignProtocol {
}

pub async fn run(mut self) -> anyhow::Result<()> {
let _span = tracing::info_span!("running", me = u32::from(self.ctx.me));
let _span = tracing::info_span!("running", my_account_id = self.ctx.account_id.to_string());
let mut queue = MpcMessageQueue::default();
loop {
tracing::debug!("trying to advance mpc recovery protocol");
@@ -215,21 +218,21 @@ impl MpcSignProtocol {
let guard = self.state.read().await;
guard.clone()
};
let state = match state.progress(&mut self.ctx).await {
let state = match state.progress(&mut self).await {
Ok(state) => state,
Err(err) => {
tracing::info!("protocol unable to progress: {err:?}");
continue;
}
};
let mut state = match state.advance(&self.ctx, contract_state).await {
let mut state = match state.advance(&self, contract_state).await {
Ok(state) => state,
Err(err) => {
tracing::info!("protocol unable to advance: {err:?}");
continue;
}
};
if let Err(err) = state.handle(&self.ctx, &mut queue).await {
if let Err(err) = state.handle(&self, &mut queue).await {
tracing::info!("protocol unable to handle messages: {err:?}");
continue;
}
@@ -242,3 +245,15 @@ impl MpcSignProtocol {
}
}
}

async fn get_my_participant(protocol: &MpcSignProtocol) -> Participant {
let my_near_acc_id = protocol.ctx.account_id.clone();
let state = protocol.state.read().await;
let participant_info = state
.find_participant_info(&my_near_acc_id)
.unwrap_or_else(|| {
tracing::error!("could not find participant info for {my_near_acc_id}");
panic!("could not find participant info for {my_near_acc_id}");
});
participant_info.id.into()
}
16 changes: 3 additions & 13 deletions node/src/protocol/presignature.rs
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@ use cait_sith::{KeygenOutput, PresignArguments, PresignOutput};
use k256::Secp256k1;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Unique number used to identify a specific ongoing presignature generation protocol.
/// Without `PresignatureId` it would be unclear where to route incoming cait-sith presignature
@@ -99,7 +98,7 @@ impl PresignatureManager {
private_share: &SecretKeyShare,
mine: bool,
) -> Result<PresignatureGenerator, InitializationError> {
let protocol = Arc::new(std::sync::RwLock::new(cait_sith::presign(
let protocol = Box::new(cait_sith::presign(
participants,
me,
PresignArguments {
@@ -111,7 +110,7 @@ impl PresignatureManager {
},
threshold,
},
)?));
)?);
Ok(PresignatureGenerator {
protocol,
triple0: triple0.id,
@@ -213,16 +212,7 @@ impl PresignatureManager {
let mut result = Ok(());
self.generators.retain(|id, generator| {
loop {
let mut protocol = match generator.protocol.write() {
Ok(protocol) => protocol,
Err(err) => {
tracing::error!(
?err,
"failed to acquire lock on presignature generation protocol"
);
break false;
}
};
let protocol = &mut generator.protocol;
let action = match protocol.poke() {
Ok(action) => action,
Err(e) => {
16 changes: 3 additions & 13 deletions node/src/protocol/signature.rs
Original file line number Diff line number Diff line change
@@ -17,7 +17,6 @@ use rand::seq::{IteratorRandom, SliceRandom};
use rand::SeedableRng;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

pub struct SignRequest {
pub receipt_id: CryptoHash,
@@ -145,13 +144,13 @@ impl SignatureManager {
k: k * delta.invert().unwrap(),
sigma: (sigma + epsilon * k) * delta.invert().unwrap(),
};
let protocol = Arc::new(std::sync::RwLock::new(cait_sith::sign(
let protocol = Box::new(cait_sith::sign(
participants,
me,
kdf::derive_key(public_key, epsilon),
output,
Scalar::from_bytes(&msg_hash),
)?));
)?);
Ok(SignatureGenerator {
protocol,
proposer,
@@ -237,16 +236,7 @@ impl SignatureManager {
let mut result = Ok(());
self.generators.retain(|receipt_id, generator| {
loop {
let mut protocol = match generator.protocol.write() {
Ok(protocol) => protocol,
Err(err) => {
tracing::error!(
?err,
"failed to acquire lock on signature generation protocol"
);
break false;
}
};
let protocol = &mut generator.protocol;
let action = match protocol.poke() {
Ok(action) => action,
Err(e) => {
35 changes: 26 additions & 9 deletions node/src/protocol/state.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::contract::primitives::{ParticipantInfo, Participants};
use super::cryptography::CryptographicError;
use super::presignature::PresignatureManager;
use super::signature::SignatureManager;
use super::triple::TripleManager;
use super::SignQueue;
use crate::http_client::MessageQueue;
use crate::protocol::ParticipantInfo;
use crate::types::{KeygenProtocol, PublicKey, ReshareProtocol, SecretKeyShare};
use cait_sith::protocol::Participant;
use near_primitives::types::AccountId;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::RwLock;

@@ -24,7 +24,7 @@ pub struct StartedState(pub Option<PersistentNodeData>);

#[derive(Clone)]
pub struct GeneratingState {
pub participants: BTreeMap<Participant, ParticipantInfo>,
pub participants: Participants,
pub threshold: usize,
pub protocol: KeygenProtocol,
pub messages: Arc<RwLock<MessageQueue>>,
@@ -42,7 +42,7 @@ impl GeneratingState {
#[derive(Clone)]
pub struct WaitingForConsensusState {
pub epoch: u64,
pub participants: BTreeMap<Participant, ParticipantInfo>,
pub participants: Participants,
pub threshold: usize,
pub private_share: SecretKeyShare,
pub public_key: PublicKey,
@@ -61,7 +61,7 @@ impl WaitingForConsensusState {
#[derive(Clone)]
pub struct RunningState {
pub epoch: u64,
pub participants: BTreeMap<Participant, ParticipantInfo>,
pub participants: Participants,
pub threshold: usize,
pub private_share: SecretKeyShare,
pub public_key: PublicKey,
@@ -84,8 +84,8 @@ impl RunningState {
#[derive(Clone)]
pub struct ResharingState {
pub old_epoch: u64,
pub old_participants: BTreeMap<Participant, ParticipantInfo>,
pub new_participants: BTreeMap<Participant, ParticipantInfo>,
pub old_participants: Participants,
pub new_participants: Participants,
pub threshold: usize,
pub public_key: PublicKey,
pub protocol: ReshareProtocol,
@@ -104,7 +104,7 @@ impl ResharingState {

#[derive(Clone)]
pub struct JoiningState {
pub participants: BTreeMap<Participant, ParticipantInfo>,
pub participants: Participants,
pub public_key: PublicKey,
}

@@ -144,11 +144,28 @@ impl NodeState {
_ => Err(CryptographicError::UnknownParticipant(*p)),
}
}

pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> {
match self {
NodeState::Starting => None,
NodeState::Started(_) => None,
NodeState::Generating(state) => state.participants.find_participant_info(account_id),
NodeState::WaitingForConsensus(state) => {
state.participants.find_participant_info(account_id)
}
NodeState::Running(state) => state.participants.find_participant_info(account_id),
NodeState::Resharing(state) => state
.new_participants
.find_participant_info(account_id)
.or_else(|| state.old_participants.find_participant_info(account_id)),
NodeState::Joining(state) => state.participants.find_participant_info(account_id),
}
}
}

fn fetch_participant<'a>(
p: &Participant,
participants: &'a BTreeMap<Participant, ParticipantInfo>,
participants: &'a Participants,
) -> Result<&'a ParticipantInfo, CryptographicError> {
participants
.get(p)
280 changes: 216 additions & 64 deletions node/src/protocol/triple.rs
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@ use crate::types::TripleProtocol;
use crate::util::AffinePointExt;
use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use cait_sith::triples::{TriplePub, TripleShare};
use highway::{HighwayHash, HighwayHasher};
use k256::elliptic_curve::group::GroupEncoding;
use k256::Secp256k1;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Unique number used to identify a specific ongoing triple generation protocol.
/// Without `TripleId` it would be unclear where to route incoming cait-sith triple generation
@@ -21,28 +22,20 @@ pub struct Triple {
pub public: TriplePub<Secp256k1>,
}

/// An ongoing triple generator.
pub struct TripleGenerator {
/// Ongoing cait-sith triple generation protocol.
pub protocol: TripleProtocol,
/// Whether this triple generation was initiated by the current node.
pub mine: bool,
}

/// Abstracts how triples are generated by providing a way to request a new triple that will be
/// complete some time in the future and a way to take an already generated triple.
pub struct TripleManager {
/// Completed unspent triples
triples: HashMap<TripleId, Triple>,
pub triples: HashMap<TripleId, Triple>,
/// Ongoing triple generation protocols
generators: HashMap<TripleId, TripleGenerator>,
pub generators: HashMap<TripleId, TripleProtocol>,
/// List of triple ids generation of which was initiated by the current node.
mine: VecDeque<TripleId>,
pub mine: VecDeque<TripleId>,

participants: Vec<Participant>,
me: Participant,
threshold: usize,
epoch: u64,
pub participants: Vec<Participant>,
pub me: Participant,
pub threshold: usize,
pub epoch: u64,
}

impl TripleManager {
@@ -83,20 +76,12 @@ impl TripleManager {
pub fn generate(&mut self) -> Result<(), InitializationError> {
let id = rand::random();
tracing::debug!(id, "starting protocol to generate a new triple");
let protocol: TripleProtocol = Arc::new(std::sync::RwLock::new(
cait_sith::triples::generate_triple::<Secp256k1>(
&self.participants,
self.me,
self.threshold,
)?,
));
self.generators.insert(
id,
TripleGenerator {
protocol,
mine: true,
},
);
let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&self.participants,
self.me,
self.threshold,
)?);
self.generators.insert(id, protocol);
Ok(())
}

@@ -150,20 +135,15 @@ impl TripleManager {
match self.generators.entry(id) {
Entry::Vacant(e) => {
tracing::debug!(id, "joining protocol to generate a new triple");
let protocol = Arc::new(std::sync::RwLock::new(
cait_sith::triples::generate_triple::<Secp256k1>(
&self.participants,
self.me,
self.threshold,
)?,
));
let generator = e.insert(TripleGenerator {
protocol,
mine: false,
});
Ok(Some(&mut generator.protocol))
let protocol = Box::new(cait_sith::triples::generate_triple::<Secp256k1>(
&self.participants,
self.me,
self.threshold,
)?);
let generator = e.insert(protocol);
Ok(Some(generator))
}
Entry::Occupied(e) => Ok(Some(&mut e.into_mut().protocol)),
Entry::Occupied(e) => Ok(Some(e.into_mut())),
}
}
}
@@ -175,19 +155,8 @@ impl TripleManager {
pub fn poke(&mut self) -> Result<Vec<(Participant, TripleMessage)>, ProtocolError> {
let mut messages = Vec::new();
let mut result = Ok(());
self.generators.retain(|id, generator| {
self.generators.retain(|id, protocol| {
loop {
let mut protocol = match generator.protocol.write() {
Ok(protocol) => protocol,
Err(err) => {
tracing::error!(
?err,
"failed to acquire lock on triple generation protocol"
);
break false;
}
};

let action = match protocol.poke() {
Ok(action) => action,
Err(e) => {
@@ -232,18 +201,38 @@ impl TripleManager {
big_c = ?output.1.big_c.to_base58(),
"completed triple generation"
);
self.triples.insert(
*id,
Triple {
id: *id,
share: output.0,
public: output.1,
},
);

if generator.mine {
let triple = Triple {
id: *id,
share: output.0,
public: output.1,
};

// After creation the triple is assigned to a random node, which is NOT necessarily the one that initiated it's creation
let triple_is_mine = {
// This is an entirely unpredictable value to all participants because it's a combination of big_c_i
// It is the same value across all participants
let big_c = triple.public.big_c;

// We turn this into a u64 in a way not biased to the structure of the byte serialisation so we hash it
// We use Highway Hash because the DefaultHasher doesn't guarantee a consistent output across versions
let entropy =
HighwayHasher::default().hash64(&big_c.to_bytes()) as usize;

let num_participants = self.participants.len();
// This has a *tiny* bias towards lower indexed participants, they're up to (1 + num_participants / u64::MAX)^2 times more likely to be selected
// This is acceptably small that it will likely never result in a biased selection happening
let triple_owner = self.participants[entropy % num_participants];

triple_owner == self.me
};

if triple_is_mine {
self.mine.push_back(*id);
}

self.triples.insert(*id, triple);

// Do not retain the protocol
break false;
}
@@ -253,3 +242,166 @@ impl TripleManager {
result.map(|_| messages)
}
}

#[cfg(test)]
mod test {
use std::{collections::HashMap, fs::OpenOptions, ops::Range};

use crate::protocol::message::TripleMessage;
use cait_sith::protocol::{InitializationError, Participant, ProtocolError};
use itertools::multiunzip;
use std::io::prelude::*;

use super::TripleManager;

struct TestManagers {
managers: Vec<TripleManager>,
}

impl TestManagers {
fn new(number: u32) -> Self {
let range = 0..number;
// Self::wipe_mailboxes(range.clone());
let participants: Vec<Participant> = range.map(Participant::from).collect();
let managers = participants
.iter()
.map(|me| TripleManager::new(participants.clone(), *me, number as usize, 0))
.collect();
TestManagers { managers }
}

fn generate(&mut self, index: usize) -> Result<(), InitializationError> {
self.managers[index].generate()
}

fn poke(&mut self, index: usize) -> Result<bool, ProtocolError> {
let mut quiet = true;
let messages = self.managers[index].poke()?;
for (
participant,
ref tm @ TripleMessage {
id, from, ref data, ..
},
) in messages
{
// Self::debug_mailbox(participant.into(), &tm);
quiet = false;
let participant_i: u32 = participant.into();
let manager = &mut self.managers[participant_i as usize];
if let Some(protocol) = manager.get_or_generate(id).unwrap() {
protocol.message(from, data.to_vec());
} else {
println!("Tried to write to completed mailbox {:?}", tm);
}
}
Ok(quiet)
}

#[allow(unused)]
fn wipe_mailboxes(mailboxes: Range<u32>) {
for m in mailboxes {
let mut file = OpenOptions::new()
.write(true)
.append(false)
.create(true)
.open(format!("{}.csv", m))
.unwrap();
write!(file, "").unwrap();
}
}

// This allows you to see what each node is recieving and when
#[allow(unused)]
fn debug_mailbox(participant: u32, TripleMessage { id, from, data, .. }: &TripleMessage) {
let mut file = OpenOptions::new()
.write(true)
.append(true)
.open(format!("{}.csv", participant))
.unwrap();

writeln!(file, "'{id}, {from:?}, {}", hex::encode(data)).unwrap();
}

fn poke_until_quiet(&mut self) -> Result<(), ProtocolError> {
loop {
let mut quiet = true;
for i in 0..self.managers.len() {
let poke = self.poke(i)?;
quiet = quiet && poke;
}
if quiet {
return Ok(());
}
}
}
}

// TODO: This test currently takes 22 seconds on my machine, which is much slower than it should be
// Improve this before we make more similar tests
#[test]
fn happy_triple_generation() {
let mut tm = TestManagers::new(5);

const M: usize = 2;
const N: usize = M + 3;
// Generate 5 triples
for _ in 0..M {
tm.generate(0).unwrap();
}
tm.poke_until_quiet().unwrap();
tm.generate(1).unwrap();
tm.generate(2).unwrap();
tm.generate(4).unwrap();

tm.poke_until_quiet().unwrap();

let inputs = tm
.managers
.into_iter()
.map(|m| (m.my_len(), m.len(), m.generators, m.triples));

let (my_lens, lens, generators, mut triples): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
multiunzip(inputs);

assert_eq!(
my_lens.iter().sum::<usize>(),
N,
"There should be {N} owned completed triples in total",
);

for l in lens {
assert_eq!(l, N, "All nodes should have {N} completed triples")
}

// This passes, but we don't have deterministic entropy or enough triples
// to ensure that it will no coincidentally fail
// TODO: deterministic entropy for testing
// assert_ne!(
// my_lens,
// vec![M, 1, 1, 0, 1],
// "The nodes that started the triple don't own it"
// );

for g in generators.iter() {
assert!(g.is_empty(), "There are no triples still being generated")
}

assert_ne!(
triples.len(),
1,
"The number of triples is not 1 before deduping"
);

triples.dedup_by_key(|kv| {
kv.iter_mut()
.map(|(id, triple)| (*id, (triple.id, triple.public.clone())))
.collect::<HashMap<_, _>>()
});

assert_eq!(
triples.len(),
1,
"All triple IDs and public parts are identical"
)
}
}
8 changes: 3 additions & 5 deletions node/src/types.rs
Original file line number Diff line number Diff line change
@@ -11,8 +11,6 @@ pub type PublicKey = <Secp256k1 as CurveArithmetic>::AffinePoint;
pub type KeygenProtocol = Arc<RwLock<dyn Protocol<Output = KeygenOutput<Secp256k1>> + Send + Sync>>;
pub type ReshareProtocol = Arc<RwLock<dyn Protocol<Output = SecretKeyShare> + Send + Sync>>;
pub type TripleProtocol =
Arc<std::sync::RwLock<dyn Protocol<Output = TripleGenerationOutput<Secp256k1>> + Send + Sync>>;
pub type PresignatureProtocol =
Arc<std::sync::RwLock<dyn Protocol<Output = PresignOutput<Secp256k1>> + Send + Sync>>;
pub type SignatureProtocol =
Arc<std::sync::RwLock<dyn Protocol<Output = FullSignature<Secp256k1>> + Send + Sync>>;
Box<dyn Protocol<Output = TripleGenerationOutput<Secp256k1>> + Send + Sync>;
pub type PresignatureProtocol = Box<dyn Protocol<Output = PresignOutput<Secp256k1>> + Send + Sync>;
pub type SignatureProtocol = Box<dyn Protocol<Output = FullSignature<Secp256k1>> + Send + Sync>;
2 changes: 1 addition & 1 deletion node/src/web/error.rs
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ impl axum::response::IntoResponse for Error {
let status = self.status();
let message = match self {
Error::JsonExtractorRejection(json_rejection) => json_rejection.body_text(),
err => err.to_string(),
err => format!("{err:?}"),
};

(status, axum::Json(message)).into_response()
8 changes: 4 additions & 4 deletions node/src/web/mod.rs
Original file line number Diff line number Diff line change
@@ -100,13 +100,13 @@ async fn msg(
#[tracing::instrument(level = "debug", skip_all)]
async fn join(
Extension(state): Extension<Arc<AxumState>>,
WithRejection(Json(participant), _): WithRejection<Json<Participant>, Error>,
WithRejection(Json(account_id), _): WithRejection<Json<AccountId>, Error>,
) -> Result<()> {
let protocol_state = state.protocol_state.read().await;
match &*protocol_state {
NodeState::Running { .. } => {
let args = serde_json::json!({
"participant": participant
"candidate_account_id": account_id
});
match state
.rpc_client
@@ -123,7 +123,7 @@ async fn join(
.await
{
Ok(_) => {
tracing::info!(?participant, "successfully voted for a node to join");
tracing::info!(?account_id, "successfully voted for a node to join");
Ok(())
}
Err(e) => {
@@ -133,7 +133,7 @@ async fn join(
}
}
_ => {
tracing::debug!(?participant, "not ready to accept join requests yet");
tracing::debug!(?account_id, "not ready to accept join requests yet");
Err(Error::NotRunning)
}
}

0 comments on commit 82ba4fd

Please sign in to comment.