From 13d571f3a7ca4358306eceb820b7cb2b3cddd2be Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Thu, 26 Oct 2023 17:47:21 -0400 Subject: [PATCH] Fix linter issues --- common/Cargo.toml | 8 +- common/datagen/datagen.rs | 7 +- common/src/files.rs | 4 +- common/src/s3_path.rs | 19 +- protocol-rpc/src/connect/create_client.rs | 40 +-- protocol-rpc/src/proto/mod.rs | 8 +- protocol-rpc/src/rpc/dpmc/client.rs | 162 ++++++----- protocol-rpc/src/rpc/dpmc/company-server.rs | 29 +- protocol-rpc/src/rpc/dpmc/partner-server.rs | 37 ++- .../src/rpc/dpmc/rpc_client_company.rs | 21 +- .../src/rpc/dpmc/rpc_client_partner.rs | 21 +- .../src/rpc/dpmc/rpc_server_company.rs | 84 +++--- .../src/rpc/dpmc/rpc_server_partner.rs | 75 +++-- protocol-rpc/src/rpc/dspmc/client.rs | 263 +++++++++--------- protocol-rpc/src/rpc/dspmc/company-server.rs | 47 ++-- protocol-rpc/src/rpc/dspmc/helper-server.rs | 36 ++- protocol-rpc/src/rpc/dspmc/partner-server.rs | 52 ++-- .../src/rpc/dspmc/rpc_client_company.rs | 22 +- .../src/rpc/dspmc/rpc_client_helper.rs | 27 +- .../src/rpc/dspmc/rpc_client_partner.rs | 21 +- .../src/rpc/dspmc/rpc_server_company.rs | 154 ++++++---- .../src/rpc/dspmc/rpc_server_helper.rs | 120 ++++---- .../src/rpc/dspmc/rpc_server_partner.rs | 68 +++-- protocol/src/dpmc/company.rs | 97 ++++--- protocol/src/dpmc/helper.rs | 221 ++++++++------- protocol/src/dpmc/mod.rs | 22 +- protocol/src/dpmc/partner.rs | 96 ++++--- protocol/src/dpmc/traits.rs | 21 +- protocol/src/dspmc/company.rs | 226 +++++++-------- protocol/src/dspmc/helper.rs | 242 ++++++++-------- protocol/src/dspmc/mod.rs | 24 +- protocol/src/dspmc/partner.rs | 77 ++--- protocol/src/dspmc/shuffler.rs | 173 +++++------- protocol/src/dspmc/traits.rs | 19 +- protocol/src/lib.rs | 4 +- 35 files changed, 1350 insertions(+), 1197 deletions(-) diff --git a/common/Cargo.toml b/common/Cargo.toml index f0f782e..87b8e10 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -23,10 +23,10 @@ hex = "0.3.0" serde = {version = "1.0.104", features = ["derive"] } num = "0.2.1" wasm-timer = "0.2.5" -aws-config = "0.54.1" -aws-credential-types = "0.54.1" -aws-sdk-s3 = "0.24.0" -aws-smithy-http = "0.54.0" +aws-config = "0.56.1" +aws-credential-types = "0.56.1" +aws-sdk-s3 = "0.34.0" +aws-smithy-http = "0.56.0" lazy_static = "1.4.0" regex = "1.5.4" tempfile = "3.2.0" diff --git a/common/datagen/datagen.rs b/common/datagen/datagen.rs index a50c933..08ed254 100644 --- a/common/datagen/datagen.rs +++ b/common/datagen/datagen.rs @@ -169,14 +169,17 @@ fn main() { let fn_a = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "a", size, cols); let fn_b = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "b", size, cols); - let fn_b_features = format!("{}/input_{}_size_{}_cols_{}_features.csv", dir, "b", size, cols); + let fn_b_features = format!( + "{}/input_{}_size_{}_cols_{}_features.csv", + dir, "b", size, cols + ); info!("Generating output of size {}", size); info!("Player a output: {}", fn_a); info!("Player b output: {}", fn_b); info!("Player b features: {}", fn_b_features); - let intrsct = size / 2 as usize; + let intrsct = size / 2_usize; let size_player = size - intrsct; let data = gen::random_data(size_player, size_player, intrsct); info!("Data generation done, writing to files"); diff --git a/common/src/files.rs b/common/src/files.rs index e4bfb15..8709fae 100644 --- a/common/src/files.rs +++ b/common/src/files.rs @@ -106,7 +106,9 @@ where it.map(|x| { x.unwrap() .iter() - .map(|z| u64::from_str(z.trim()).unwrap_or_else(|_| panic!("Cannot format {} as u64", z))) + .map(|z| { + u64::from_str(z.trim()).unwrap_or_else(|_| panic!("Cannot format {} as u64", z)) + }) .collect::>() }) .collect::>>() diff --git a/common/src/s3_path.rs b/common/src/s3_path.rs index 29fc530..615d75d 100644 --- a/common/src/s3_path.rs +++ b/common/src/s3_path.rs @@ -8,11 +8,6 @@ use std::path::Path; use std::str::FromStr; use std::time::Duration; -use aws_sdk_s3::Region; -use aws_sdk_s3::error::NoSuchUpload; -use aws_sdk_s3::model::CompletedPart; -use aws_sdk_s3::model::CompletedMultipartUpload; -use aws_sdk_s3::types::ByteStream; use aws_config::default_provider::credentials::default_provider; use aws_credential_types::cache::CredentialsCache; use regex::Regex; @@ -54,7 +49,7 @@ impl S3Path { pub async fn copy_to_local(&self) -> Result { let default_provider = default_provider().await; - let region = Region::new(self.get_region().clone()); + let region = aws_sdk_s3::config::Region::new(self.get_region().clone()); let aws_cfg = aws_config::from_env() .credentials_cache( CredentialsCache::lazy_builder() @@ -101,7 +96,7 @@ impl S3Path { pub async fn copy_from_local(&self, path: impl AsRef) -> Result<(), aws_sdk_s3::Error> { let default_provider = default_provider().await; - let region = Region::new(self.get_region().clone()); + let region = aws_sdk_s3::config::Region::new(self.get_region().clone()); let aws_cfg = aws_config::from_env() .region(region) .credentials_cache( @@ -136,12 +131,12 @@ impl S3Path { .unwrap(); let uid = u.upload_id().ok_or_else(|| { aws_sdk_s3::Error::NoSuchUpload( - NoSuchUpload::builder() + aws_sdk_s3::types::error::NoSuchUpload::builder() .message("No upload ID") .build(), ) })?; - let mut completed_parts: Vec = Vec::new(); + let mut completed_parts: Vec = Vec::new(); for i in 0..chunks { let length = if i == chunks - 1 { // If we're on the last chunk, the length to read might be less than a whole chunk. @@ -151,7 +146,7 @@ impl S3Path { } else { chunk_size }; - let byte_stream = ByteStream::read_from() + let byte_stream = aws_sdk_s3::primitives::ByteStream::read_from() .path(path.as_ref()) .offset(i * chunk_size) .length(aws_smithy_http::byte_stream::Length::Exact(length)) @@ -167,14 +162,14 @@ impl S3Path { .send() .await .unwrap(); - let cp = CompletedPart::builder() + let cp = aws_sdk_s3::types::CompletedPart::builder() .set_e_tag(upload.e_tag) .part_number((i + 1) as i32) .build(); completed_parts.push(cp); } // Complete multipart upload, sending the (etag, part id) list along the request. - let b = CompletedMultipartUpload::builder() + let b = aws_sdk_s3::types::CompletedMultipartUpload::builder() .set_parts(Some(completed_parts)) .build(); let completed = client diff --git a/protocol-rpc/src/connect/create_client.rs b/protocol-rpc/src/connect/create_client.rs index 612831c..92b9887 100644 --- a/protocol-rpc/src/connect/create_client.rs +++ b/protocol-rpc/src/connect/create_client.rs @@ -16,15 +16,15 @@ use tonic::transport::Endpoint; use crate::connect::tls; use crate::proto::gen_crosspsi::cross_psi_client::CrossPsiClient; use crate::proto::gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient; -use crate::proto::gen_pjc::pjc_client::PjcClient; -use crate::proto::gen_private_id::private_id_client::PrivateIdClient; -use crate::proto::gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; -use crate::proto::gen_suid_create::suid_create_client::SuidCreateClient; use crate::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; use crate::proto::gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; use crate::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; use crate::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; use crate::proto::gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; +use crate::proto::gen_pjc::pjc_client::PjcClient; +use crate::proto::gen_private_id::private_id_client::PrivateIdClient; +use crate::proto::gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; +use crate::proto::gen_suid_create::suid_create_client::SuidCreateClient; use crate::proto::RpcClient; pub fn create_client( @@ -145,21 +145,11 @@ pub fn create_client( "cross-psi-xor" => RpcClient::CrossPsiXor(CrossPsiXorClient::new(conn)), "pjc" => RpcClient::Pjc(PjcClient::new(conn)), "suid-create" => RpcClient::SuidCreate(SuidCreateClient::new(conn)), - "dpmc-company" => RpcClient::DpmcCompany( - DpmcCompanyClient::new(conn), - ), - "dpmc-partner" => RpcClient::DpmcPartner( - DpmcPartnerClient::new(conn), - ), - "dspmc-company" => RpcClient::DspmcCompany( - DspmcCompanyClient::new(conn), - ), - "dspmc-helper" => RpcClient::DspmcHelper( - DspmcHelperClient::new(conn), - ), - "dspmc-partner" => RpcClient::DspmcPartner( - DspmcPartnerClient::new(conn), - ), + "dpmc-company" => RpcClient::DpmcCompany(DpmcCompanyClient::new(conn)), + "dpmc-partner" => RpcClient::DpmcPartner(DpmcPartnerClient::new(conn)), + "dspmc-company" => RpcClient::DspmcCompany(DspmcCompanyClient::new(conn)), + "dspmc-helper" => RpcClient::DspmcHelper(DspmcHelperClient::new(conn)), + "dspmc-partner" => RpcClient::DspmcPartner(DspmcPartnerClient::new(conn)), _ => panic!("wrong client"), }) } else { @@ -187,19 +177,13 @@ pub fn create_client( DpmcPartnerClient::connect(__uri).await.unwrap(), )), "dspmc-company" => Ok(RpcClient::DspmcCompany( - DspmcCompanyClient::connect(__uri) - .await - .unwrap(), + DspmcCompanyClient::connect(__uri).await.unwrap(), )), "dspmc-helper" => Ok(RpcClient::DspmcHelper( - DspmcHelperClient::connect(__uri) - .await - .unwrap(), + DspmcHelperClient::connect(__uri).await.unwrap(), )), "dspmc-partner" => Ok(RpcClient::DspmcPartner( - DspmcPartnerClient::connect(__uri) - .await - .unwrap(), + DspmcPartnerClient::connect(__uri).await.unwrap(), )), _ => panic!("wrong client"), } diff --git a/protocol-rpc/src/proto/mod.rs b/protocol-rpc/src/proto/mod.rs index 37161dc..46d7256 100644 --- a/protocol-rpc/src/proto/mod.rs +++ b/protocol-rpc/src/proto/mod.rs @@ -53,15 +53,15 @@ pub mod streaming; use gen_crosspsi::cross_psi_client::CrossPsiClient; use gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient; -use gen_pjc::pjc_client::PjcClient; -use gen_private_id::private_id_client::PrivateIdClient; -use gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; -use gen_suid_create::suid_create_client::SuidCreateClient; use gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; use gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; use gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; use gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; use gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; +use gen_pjc::pjc_client::PjcClient; +use gen_private_id::private_id_client::PrivateIdClient; +use gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; +use gen_suid_create::suid_create_client::SuidCreateClient; use tonic::transport::Channel; pub enum RpcClient { PrivateId(PrivateIdClient), diff --git a/protocol-rpc/src/rpc/dpmc/client.rs b/protocol-rpc/src/rpc/dpmc/client.rs index c3c696f..cdb4931 100644 --- a/protocol-rpc/src/rpc/dpmc/client.rs +++ b/protocol-rpc/src/rpc/dpmc/client.rs @@ -1,29 +1,26 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -use clap::{App, Arg, ArgGroup}; -use log::{error, info}; use std::convert::TryInto; -use tonic::Request; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use common::timer; use crypto::prelude::TPayload; -use protocol::dpmc::{helper::HelperDpmc, traits::*}; -use rpc::{ - connect::create_client::create_client, - proto::{ - gen_dpmc_company::{ - service_response::Ack as CompanyAck, - Init as CompanyInit, - ServiceResponse as CompanyServiceResponse - }, - gen_dpmc_partner::{ - service_response::Ack as PartnerAck, - Init as PartnerInit, - SendData as PartnerSendData, - }, - RpcClient, - }, -}; +use log::error; +use log::info; +use protocol::dpmc::helper::HelperDpmc; +use protocol::dpmc::traits::*; +use rpc::connect::create_client::create_client; +use rpc::proto::gen_dpmc_company::service_response::Ack as CompanyAck; +use rpc::proto::gen_dpmc_company::Init as CompanyInit; +use rpc::proto::gen_dpmc_company::ServiceResponse as CompanyServiceResponse; +use rpc::proto::gen_dpmc_partner::service_response::Ack as PartnerAck; +use rpc::proto::gen_dpmc_partner::Init as PartnerInit; +use rpc::proto::gen_dpmc_partner::SendData as PartnerSendData; +use rpc::proto::RpcClient; +use tonic::Request; mod rpc_client_company; mod rpc_client_partner; @@ -62,13 +59,16 @@ async fn main() -> Result<(), Box> { Arg::with_name("output-shares-path") .long("output-shares-path") .takes_value(true) - .help("path to write shares of features.\n - Feature will be written as {path}_partner_features.csv"), + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), Arg::with_name("one-to-many") .long("one-to-many") .takes_value(true) .required(false) - .help("By default, DPMC generates one-to-one matches. Use this\n + .help( + "By default, DPMC generates one-to-one matches. Use this\n flag to generate one(C)-to-many(P) matches.", ), Arg::with_name("no-tls") @@ -226,34 +226,34 @@ async fn main() -> Result<(), Box> { for i in 0..partner_client_context.len() { // Send company public key - let _ = - match rpc_client_partner::send( - company_public_key.clone(), - "company_public_key".to_string(), - &mut partner_client_context[i]) - .await? - .into_inner() - .ack - .unwrap() - { - PartnerAck::CompanyPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_partner::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; // Send helper public key - let _ = - match rpc_client_partner::send( - helper_public_key.clone(), - "helper_public_key".to_string(), - &mut partner_client_context[i]) - .await? - .into_inner() - .ack - .unwrap() - { - PartnerAck::HelperPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_partner::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; } } @@ -273,11 +273,23 @@ async fn main() -> Result<(), Box> { .await?; let offset_len = u64::from_le_bytes( - h_company_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(), + h_company_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), ) as usize; // flattened len let data_len = u64::from_le_bytes( - h_company_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(), + h_company_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), ) as usize; let offset = h_company_beta @@ -331,7 +343,13 @@ async fn main() -> Result<(), Box> { .await?; let xor_shares_len = u64::from_le_bytes( - h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap() + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), ) as usize; let xor_shares = h_partner_alpha_beta @@ -346,11 +364,23 @@ async fn main() -> Result<(), Box> { // deserialize ragged array let num_partner_keys = u64::from_le_bytes( - h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(), + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), ) as usize; // flattened len let data_len = u64::from_le_bytes( - h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(), + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), ) as usize; let offset = h_partner_alpha_beta @@ -364,7 +394,11 @@ async fn main() -> Result<(), Box> { // Perform 1/alpha, where alpha = partner.alpha. // Then decrypt XOR secret shares and compute features and mask. helper_protocol.remove_partner_scalar_from_p_and_set_shares( - h_partner_alpha_beta, offset, enc_alpha_t.buffer, vec![p_scalar_times_g], xor_shares + h_partner_alpha_beta, + offset, + enc_alpha_t.buffer, + vec![p_scalar_times_g], + xor_shares, )?; } @@ -385,14 +419,12 @@ async fn main() -> Result<(), Box> { let v_d_prime = helper_protocol.calculate_features_xor_shares()?; // 13. Set XOR share of features for company - let _ = rpc_client_company::calculate_features_xor_shares( - v_d_prime, - &mut company_client_context, - ) - .await? - .into_inner() - .ack - .unwrap(); + let _ = + rpc_client_company::calculate_features_xor_shares(v_d_prime, &mut company_client_context) + .await? + .into_inner() + .ack + .unwrap(); // 14. Print Company's ID spine and save partners shares rpc_client_company::reveal(&mut company_client_context).await?; @@ -405,7 +437,9 @@ async fn main() -> Result<(), Box> { // 16. Print Helper's feature shares match output_shares_path { - Some(p) => helper_protocol.save_features_shares(&String::from(p)).unwrap(), + Some(p) => helper_protocol + .save_features_shares(&String::from(p)) + .unwrap(), None => error!("Output features path not set. Can't output shares"), }; diff --git a/protocol-rpc/src/rpc/dpmc/company-server.rs b/protocol-rpc/src/rpc/dpmc/company-server.rs index 10c8227..ef457f6 100644 --- a/protocol-rpc/src/rpc/dpmc/company-server.rs +++ b/protocol-rpc/src/rpc/dpmc/company-server.rs @@ -7,15 +7,20 @@ extern crate clap; extern crate ctrlc; extern crate tonic; -use clap::{App, Arg, ArgGroup}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use log::info; -use std::{sync::{atomic::{AtomicBool, Ordering}, Arc}, thread, time,}; mod rpc_server_company; -use rpc::{ - connect::create_server::create_server, - proto::gen_dpmc_company::dpmc_company_server, -}; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dpmc_company::dpmc_company_server; #[tokio::main] async fn main() -> Result<(), Box> { @@ -53,8 +58,10 @@ async fn main() -> Result<(), Box> { .long("output-shares-path") .takes_value(true) .required(true) - .help("path to write shares of features.\n - Feature will be written as {path}_partner_features.csv"), + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), Arg::with_name("no-tls") .long("no-tls") .takes_value(false) @@ -139,7 +146,7 @@ async fn main() -> Result<(), Box> { input_path, output_keys_path, output_shares_path, - input_with_headers + input_with_headers, ); let ks = service.killswitch.clone(); @@ -158,9 +165,7 @@ async fn main() -> Result<(), Box> { let addr = host.unwrap().parse()?; server - .add_service(dpmc_company_server::DpmcCompanyServer::new( - service, - )) + .add_service(dpmc_company_server::DpmcCompanyServer::new(service)) .serve_with_shutdown(addr, async { rx.await.ok(); }) diff --git a/protocol-rpc/src/rpc/dpmc/partner-server.rs b/protocol-rpc/src/rpc/dpmc/partner-server.rs index 85abe49..d434057 100644 --- a/protocol-rpc/src/rpc/dpmc/partner-server.rs +++ b/protocol-rpc/src/rpc/dpmc/partner-server.rs @@ -7,16 +7,23 @@ extern crate clap; extern crate ctrlc; extern crate tonic; -use clap::{App, Arg, ArgGroup}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use log::info; -use std::{sync::{ atomic::{AtomicBool, Ordering}, Arc, }, thread, time}; -mod rpc_server_partner; mod rpc_client_company; -use rpc::{ - connect::{create_server::create_server, create_client::create_client,}, - proto::{gen_dpmc_partner::dpmc_partner_server, RpcClient}, -}; +mod rpc_server_partner; +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dpmc_partner::dpmc_partner_server; +use rpc::proto::RpcClient; #[tokio::main] async fn main() -> Result<(), Box> { @@ -87,15 +94,15 @@ async fn main() -> Result<(), Box> { .takes_value(true) .help("Override TLS domain for SSL cert (if host is IP)"), ]) - .groups(&[ - ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) - .required(true), - ]) + .groups(&[ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true)]) .get_matches(); let input_keys_path = matches.value_of("input-keys").unwrap_or("input_keys.csv"); - let input_features_path = matches.value_of("input-features").unwrap_or("input_features.csv"); + let input_features_path = matches + .value_of("input-features") + .unwrap_or("input_features.csv"); let input_with_headers = matches.is_present("input-with-headers"); let no_tls = matches.is_present("no-tls"); @@ -139,7 +146,7 @@ async fn main() -> Result<(), Box> { input_keys_path, input_features_path, input_with_headers, - company_client_context + company_client_context, ); let ks = service.killswitch.clone(); @@ -158,7 +165,7 @@ async fn main() -> Result<(), Box> { let addr = host.unwrap().parse()?; server - .add_service(dpmc_partner_server::DpmcPartnerServer::new(service,)) + .add_service(dpmc_partner_server::DpmcPartnerServer::new(service)) .serve_with_shutdown(addr, async { rx.await.ok(); }) diff --git a/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs b/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs index 239ea25..d26ac22 100644 --- a/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs +++ b/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs @@ -5,15 +5,17 @@ extern crate common; extern crate crypto; extern crate protocol; -use tonic::{transport::Channel, Request, Response, Status}; use common::timer; use crypto::prelude::TPayload; -use rpc::proto::{ - gen_dpmc_company::{ - dpmc_company_client::DpmcCompanyClient, Commitment, ServiceResponse, - }, - streaming::{read_from_stream, send_data}, -}; +use rpc::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use rpc::proto::gen_dpmc_company::Commitment; +use rpc::proto::gen_dpmc_company::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; pub async fn recv( response: ServiceResponse, @@ -46,7 +48,10 @@ pub async fn calculate_features_xor_shares( } pub async fn calculate_id_map(rpc: &mut DpmcCompanyClient) -> Result<(), Status> { - let _r = rpc.calculate_id_map(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } diff --git a/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs b/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs index a1edd30..10755c8 100644 --- a/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs +++ b/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs @@ -5,15 +5,15 @@ extern crate common; extern crate crypto; extern crate protocol; -use tonic::{transport::Channel, Request, Response, Status}; - use crypto::prelude::TPayload; -use rpc::proto::{ - gen_dpmc_partner::{ - dpmc_partner_client::DpmcPartnerClient, Commitment, ServiceResponse, - }, - streaming::send_data, -}; +use rpc::proto::gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; +use rpc::proto::gen_dpmc_partner::Commitment; +use rpc::proto::gen_dpmc_partner::ServiceResponse; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; pub async fn send( data: TPayload, @@ -28,6 +28,9 @@ pub async fn send( } pub async fn stop_service(rpc: &mut DpmcPartnerClient) -> Result<(), Status> { - let _r = rpc.stop_service(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } diff --git a/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs b/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs index 22106d7..502ad68 100644 --- a/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs +++ b/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs @@ -1,26 +1,34 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -use std::{ - borrow::BorrowMut, - convert::TryInto, - sync::{atomic::{AtomicBool, Ordering}, Arc}, -}; -use tonic::{Code, Request, Response, Status, Streaming}; +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + use common::timer; -use protocol::{ - dpmc::{company::CompanyDpmc, traits::CompanyDpmcProtocol}, - shared::TFeatures, -}; -use rpc::proto::{ - common::Payload, - gen_dpmc_company::{ - dpmc_company_server::DpmcCompany, service_response::*, - Commitment, CommitmentAck, Init, InitAck, ServiceResponse, - UPartnerAck, CalculateFeaturesXorSharesAck, - }, - streaming::{read_from_stream, write_to_stream, TPayloadStream}, -}; +use protocol::dpmc::company::CompanyDpmc; +use protocol::dpmc::traits::CompanyDpmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dpmc_company::dpmc_company_server::DpmcCompany; +use rpc::proto::gen_dpmc_company::service_response::*; +use rpc::proto::gen_dpmc_company::CalculateFeaturesXorSharesAck; +use rpc::proto::gen_dpmc_company::Commitment; +use rpc::proto::gen_dpmc_company::CommitmentAck; +use rpc::proto::gen_dpmc_company::Init; +use rpc::proto::gen_dpmc_company::InitAck; +use rpc::proto::gen_dpmc_company::ServiceResponse; +use rpc::proto::gen_dpmc_company::UPartnerAck; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; pub struct DpmcCompanyService { protocol: CompanyDpmc, @@ -67,16 +75,17 @@ impl DpmcCompany for DpmcCompanyService { })) } - async fn calculate_id_map(&self, _: Request) -> Result, Status> { + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("calculate_id_map") .build(); self.protocol .write_company_to_id_map() - .map(|_| { - Response::new(CommitmentAck {}) - }) + .map(|_| Response::new(CommitmentAck {})) .map_err(|_| Status::new(Code::Aborted, "cannot init the protocol for partner")) } @@ -94,7 +103,10 @@ impl DpmcCompany for DpmcCompanyService { u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_rows = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let mask = data.drain(num_features * num_rows..).map(|x| x).collect::>(); + let mask = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); let mut t = TFeatures::new(); for i in (0..num_features).rev() { @@ -109,13 +121,18 @@ impl DpmcCompany for DpmcCompanyService { .calculate_features_xor_shares(t, mask) .map(|_| { Response::new(ServiceResponse { - ack: Some(Ack::CalculateFeaturesXorSharesAck(CalculateFeaturesXorSharesAck {})), + ack: Some(Ack::CalculateFeaturesXorSharesAck( + CalculateFeaturesXorSharesAck {}, + )), }) }) .map_err(|_| Status::internal("error calculating XOR shares")) } - async fn recv_company_public_key(&self, _: Request) -> Result, Status> { + async fn recv_company_public_key( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("recv_company_public_key") @@ -189,7 +206,13 @@ impl DpmcCompany for DpmcCompanyService { assert_eq!(offset_len, offset.len()); self.protocol - .set_encrypted_partner_keys_and_shares(data, offset, enc_alpha_t.buffer, p_scalar_g.buffer, xor_shares) + .set_encrypted_partner_keys_and_shares( + data, + offset, + enc_alpha_t.buffer, + p_scalar_g.buffer, + xor_shares, + ) .map(|_| { Response::new(ServiceResponse { ack: Some(Ack::UPartnerAck(UPartnerAck {})), @@ -208,11 +231,10 @@ impl DpmcCompany for DpmcCompanyService { None => self.protocol.print_id_map(), } - let resp = self.protocol + let resp = self + .protocol .save_features_shares(&self.output_shares_path.clone().unwrap()) - .map(|_| { - Response::new(CommitmentAck {}) - }) + .map(|_| Response::new(CommitmentAck {})) .map_err(|_| Status::internal("error saving feature shares")); { debug!("Setting up flag for graceful down"); diff --git a/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs b/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs index 3cf3c07..7546cdd 100644 --- a/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs +++ b/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs @@ -1,23 +1,37 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -use std::{sync::{atomic::{AtomicBool, Ordering}, Arc,},}; -use tonic::{Code, Request, Response, Status, Streaming, transport::Channel}; -use crypto::spoint::ByteBuffer; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + use common::timer; -use protocol::dpmc::{ - partner::PartnerDpmc, traits::PartnerDpmcProtocol, -}; -use rpc::proto::{ - common::Payload, - gen_dpmc_partner::{ - dpmc_partner_server::DpmcPartner, service_response::*, - Commitment, CommitmentAck, Init, InitAck, SendData, SendDataAck, - ServiceResponse, CompanyPublicKeyAck, HelperPublicKeyAck - }, - gen_dpmc_company::dpmc_company_client::DpmcCompanyClient, - streaming::{read_from_stream, write_to_stream, send_data, TPayloadStream}, -}; +use crypto::spoint::ByteBuffer; +use protocol::dpmc::partner::PartnerDpmc; +use protocol::dpmc::traits::PartnerDpmcProtocol; +use rpc::proto::common::Payload; +use rpc::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use rpc::proto::gen_dpmc_partner::dpmc_partner_server::DpmcPartner; +use rpc::proto::gen_dpmc_partner::service_response::*; +use rpc::proto::gen_dpmc_partner::Commitment; +use rpc::proto::gen_dpmc_partner::CommitmentAck; +use rpc::proto::gen_dpmc_partner::CompanyPublicKeyAck; +use rpc::proto::gen_dpmc_partner::HelperPublicKeyAck; +use rpc::proto::gen_dpmc_partner::Init; +use rpc::proto::gen_dpmc_partner::InitAck; +use rpc::proto::gen_dpmc_partner::SendData; +use rpc::proto::gen_dpmc_partner::SendDataAck; +use rpc::proto::gen_dpmc_partner::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::transport::Channel; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; pub struct DpmcPartnerService { protocol: PartnerDpmc, @@ -55,14 +69,20 @@ impl DpmcPartner for DpmcPartnerService { .label("server") .extra_label("init") .build(); - self.protocol - .load_data(&self.input_keys_path, &self.input_features_path, self.input_with_headers); + self.protocol.load_data( + &self.input_keys_path, + &self.input_features_path, + self.input_with_headers, + ); Ok(Response::new(ServiceResponse { ack: Some(Ack::InitAck(InitAck {})), })) } - async fn send_data_to_company(&self, _: Request) -> Result, Status> { + async fn send_data_to_company( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("init") @@ -75,11 +95,13 @@ impl DpmcPartner for DpmcPartnerService { let xor_shares = self.protocol.get_features_xor_shares().unwrap(); let xor_shares_len = xor_shares.len(); h_partner_alpha.extend(xor_shares); - h_partner_alpha.push( - ByteBuffer{ buffer: (xor_shares_len as u64).to_le_bytes().to_vec(), } - ); + h_partner_alpha.push(ByteBuffer { + buffer: (xor_shares_len as u64).to_le_bytes().to_vec(), + }); - _ = company_client_contxt.send_u_partner(send_data(h_partner_alpha)).await; + _ = company_client_contxt + .send_u_partner(send_data(h_partner_alpha)) + .await; Ok(Response::new(ServiceResponse { ack: Some(Ack::SendDataAck(SendDataAck {})), @@ -88,7 +110,7 @@ impl DpmcPartner for DpmcPartnerService { async fn recv_partner_public_key( &self, - _: Request + _: Request, ) -> Result, Status> { let _ = timer::Builder::new() .label("server") @@ -138,7 +160,10 @@ impl DpmcPartner for DpmcPartnerService { .map_err(|_| Status::internal("error writing")) } - async fn stop_service(&self, _: Request) -> Result, Status> { + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("stop") diff --git a/protocol-rpc/src/rpc/dspmc/client.rs b/protocol-rpc/src/rpc/dspmc/client.rs index f590bd4..d7ae9e4 100644 --- a/protocol-rpc/src/rpc/dspmc/client.rs +++ b/protocol-rpc/src/rpc/dspmc/client.rs @@ -1,42 +1,35 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -use clap::{App, Arg, ArgGroup}; -use log::info; use std::convert::TryInto; -use tonic::Request; +use clap::App; +use clap::Arg; +use clap::ArgGroup; use common::timer; use crypto::prelude::TPayload; -use protocol::dspmc::{shuffler::ShufflerDspmc, traits::*}; +use log::info; +use protocol::dspmc::shuffler::ShufflerDspmc; +use protocol::dspmc::traits::*; use protocol::shared::TFeatures; -use rpc::{ - connect::create_client::create_client, - proto::{ - gen_dspmc_company::{ - service_response::Ack as CompanyAck, - Init as CompanyInit, - ServiceResponse as CompanyServiceResponse, - SendData as CompanySendData, - RecvShares as CompanyRecvShares, - }, - gen_dspmc_partner::{ - service_response::Ack as PartnerAck, - Init as PartnerInit, - SendData as PartnerSendData, - }, - gen_dspmc_helper::{ - service_response::Ack as HelperAck, - ServiceResponse as HelperServiceResponse, - SendDataAck - }, - RpcClient, - }, -}; +use rpc::connect::create_client::create_client; +use rpc::proto::gen_dspmc_company::service_response::Ack as CompanyAck; +use rpc::proto::gen_dspmc_company::Init as CompanyInit; +use rpc::proto::gen_dspmc_company::RecvShares as CompanyRecvShares; +use rpc::proto::gen_dspmc_company::SendData as CompanySendData; +use rpc::proto::gen_dspmc_company::ServiceResponse as CompanyServiceResponse; +use rpc::proto::gen_dspmc_helper::service_response::Ack as HelperAck; +use rpc::proto::gen_dspmc_helper::SendDataAck; +use rpc::proto::gen_dspmc_helper::ServiceResponse as HelperServiceResponse; +use rpc::proto::gen_dspmc_partner::service_response::Ack as PartnerAck; +use rpc::proto::gen_dspmc_partner::Init as PartnerInit; +use rpc::proto::gen_dspmc_partner::SendData as PartnerSendData; +use rpc::proto::RpcClient; +use tonic::Request; mod rpc_client_company; -mod rpc_client_partner; mod rpc_client_helper; +mod rpc_client_partner; #[tokio::main] async fn main() -> Result<(), Box> { @@ -112,9 +105,7 @@ async fn main() -> Result<(), Box> { ArgGroup::with_name("tls") .args(&["no-tls", "tls-dir", "tls-key"]) .required(true), - ArgGroup::with_name("out") - .args(&["stdout"]) - .required(true), + ArgGroup::with_name("out").args(&["stdout"]).required(true), ]) .get_matches(); @@ -228,24 +219,26 @@ async fn main() -> Result<(), Box> { .await?; shuffler_protocol.set_company_public_key(company_public_key.clone())?; - let helper_public_key_ack = - match rpc_client_helper::send( - company_public_key.clone(), - "company_public_key".to_string(), - &mut helper_client_context) - .await? - .into_inner() - .ack - .unwrap() - { - HelperAck::CompanyPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let helper_public_key_ack = match rpc_client_helper::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; let mut helper_public_key = TPayload::new(); let _ = rpc_client_helper::recv( HelperServiceResponse { - ack: Some(HelperAck::CompanyPublicKeyAck(helper_public_key_ack.clone())), + ack: Some(HelperAck::CompanyPublicKeyAck( + helper_public_key_ack.clone(), + )), }, "helper_public_key".to_string(), &mut helper_public_key, @@ -255,50 +248,50 @@ async fn main() -> Result<(), Box> { shuffler_protocol.set_helper_public_key(helper_public_key.clone())?; // Send helper public key to Company - let _ = - match rpc_client_company::send( - helper_public_key.clone(), - "helper_public_key".to_string(), - &mut company_client_context) - .await? - .into_inner() - .ack - .unwrap() - { - CompanyAck::HelperPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_company::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut company_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; for i in 0..partner_client_context.len() { // Send company public key to partners - let _ = - match rpc_client_partner::send( - company_public_key.clone(), - "company_public_key".to_string(), - &mut partner_client_context[i]) - .await? - .into_inner() - .ack - .unwrap() - { - PartnerAck::CompanyPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_partner::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; // Send helper public key to partners - let _ = - match rpc_client_partner::send( - helper_public_key.clone(), - "helper_public_key".to_string(), - &mut partner_client_context[i]) - .await? - .into_inner() - .ack - .unwrap() - { - PartnerAck::HelperPublicKeyAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_partner::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; } } @@ -350,13 +343,11 @@ async fn main() -> Result<(), Box> { ) .await?; - let offset_len = u64::from_le_bytes( - v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let offset_len = + u64::from_le_bytes(v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // flattened len - let data_len = u64::from_le_bytes( - v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let data_len = + u64::from_le_bytes(v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_keys = offset_len - 1; let offset = v4_p4 .drain((num_keys * 2 + data_len * 2)..) @@ -364,16 +355,10 @@ async fn main() -> Result<(), Box> { .collect::>(); assert_eq!(offset_len, offset.len()); - let ct2_prime_flat = v4_p4 - .drain((v4_p4.len()-data_len)..) - .collect::>(); - let ct1_prime_flat = v4_p4 - .drain((v4_p4.len()-data_len)..) - .collect::>(); + let ct2_prime_flat = v4_p4.drain((v4_p4.len() - data_len)..).collect::>(); + let ct1_prime_flat = v4_p4.drain((v4_p4.len() - data_len)..).collect::>(); - let v_cs_bytes = v4_p4 - .drain((v4_p4.len()-num_keys)..) - .collect::>(); + let v_cs_bytes = v4_p4.drain((v4_p4.len() - num_keys)..).collect::>(); v4_p4.shrink_to_fit(); shuffler_protocol.set_p_cs_v_cs(v_cs_bytes, v4_p4)?; @@ -401,7 +386,7 @@ async fn main() -> Result<(), Box> { .drain(i * num_rows..) .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) .collect::>(); - u2.push(x); + u2.push(x); } // 10. Generete shuffler permutations @@ -411,39 +396,43 @@ async fn main() -> Result<(), Box> { // 11. Compute x_2 = p_cs(u2) xor v_cs // Compute v_2' = p_sd(p_sc(x_2) xor v_sd) xor v_sd // Return rerandomized ct1' and ct2' as ct1'' and ct2'' - let ct1_ct2_dprime = shuffler_protocol.compute_v2prime_ct1ct2( - u2, ct1_prime_flat, ct2_prime_flat, offset - ).unwrap(); + let ct1_ct2_dprime = shuffler_protocol + .compute_v2prime_ct1ct2(u2, ct1_prime_flat, ct2_prime_flat, offset) + .unwrap(); // v_sc, p_sc, ct1_dprime_flat, ct2_dprime_flat, ct_offset let mut p_sc_v_sc_ct1_ct2_dprime = p_sc_v_sc; p_sc_v_sc_ct1_ct2_dprime.extend(ct1_ct2_dprime); // 12. Send v_sc, p_sc, ct1'', ct2'' to C - let _company_p_sc_v_sc_ack = - match rpc_client_company::send(p_sc_v_sc_ct1_ct2_dprime, - "p_sc_v_sc_ct1_ct2_dprime".to_string(), &mut company_client_context) - .await? - .into_inner() - .ack - .unwrap() - { - CompanyAck::UPartnerAck(x) => x, - _ => panic!("wrong ack"), - }; + let _company_p_sc_v_sc_ack = match rpc_client_company::send( + p_sc_v_sc_ct1_ct2_dprime, + "p_sc_v_sc_ct1_ct2_dprime".to_string(), + &mut company_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; // 13. Send p_sd, v_sd to helper (D) - let _ = - match rpc_client_helper::send(p_sd_v_sd, "p_sd_v_sd".to_string(), - &mut helper_client_context) - .await? - .into_inner() - .ack - .unwrap() - { - HelperAck::UPartnerAck(x) => x, - _ => panic!("wrong ack"), - }; + let _ = match rpc_client_helper::send( + p_sd_v_sd, + "p_sd_v_sd".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; // 14. Send request to company to send u1 to Helper // u1 = p_sc( p_cs( p_cd(v_1) xor v_cd) xor v_cs) xor v_sc @@ -465,17 +454,19 @@ async fn main() -> Result<(), Box> { let blinded_vprime = shuffler_protocol.get_blinded_vprime().unwrap(); // 15. Send blinded v' and g^z to helper (D) - let _helper_vprime_ack = - match rpc_client_helper::send(blinded_vprime, "encrypted_vprime".to_string(), - &mut helper_client_context) - .await? - .into_inner() - .ack - .unwrap() - { - HelperAck::UPartnerAck(x) => x, - _ => panic!("wrong ack"), - }; + let _helper_vprime_ack = match rpc_client_helper::send( + blinded_vprime, + "encrypted_vprime".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; // 16. Send request to company to send ct1, ct2', and X to Helper // ct2' = ct2^c diff --git a/protocol-rpc/src/rpc/dspmc/company-server.rs b/protocol-rpc/src/rpc/dspmc/company-server.rs index 239b83c..a0c8696 100644 --- a/protocol-rpc/src/rpc/dspmc/company-server.rs +++ b/protocol-rpc/src/rpc/dspmc/company-server.rs @@ -7,26 +7,24 @@ extern crate clap; extern crate ctrlc; extern crate tonic; -use clap::{App, Arg, ArgGroup}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use log::info; -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread, time, -}; -mod rpc_server_company; mod rpc_client_helper; +mod rpc_server_company; -use rpc::{ - connect::{ create_server::create_server, create_client::create_client, }, - proto::{ - gen_dspmc_company::dspmc_company_server, - RpcClient, - }, -}; +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_company::dspmc_company_server; +use rpc::proto::RpcClient; #[tokio::main] async fn main() -> Result<(), Box> { @@ -70,8 +68,10 @@ async fn main() -> Result<(), Box> { .long("output-shares-path") .takes_value(true) .required(true) - .help("path to write shares of features.\n - Feature will be written as {path}_partner_features.csv"), + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), Arg::with_name("no-tls") .long("no-tls") .takes_value(false) @@ -170,8 +170,11 @@ async fn main() -> Result<(), Box> { } let service = rpc_server_company::DspmcCompanyService::new( - input_path, output_keys_path, output_shares_path, input_with_headers, - helper_client_context + input_path, + output_keys_path, + output_shares_path, + input_with_headers, + helper_client_context, ); let ks = service.killswitch.clone(); @@ -190,9 +193,7 @@ async fn main() -> Result<(), Box> { let addr = host.unwrap().parse()?; server - .add_service(dspmc_company_server::DspmcCompanyServer::new( - service, - )) + .add_service(dspmc_company_server::DspmcCompanyServer::new(service)) .serve_with_shutdown(addr, async { rx.await.ok(); }) diff --git a/protocol-rpc/src/rpc/dspmc/helper-server.rs b/protocol-rpc/src/rpc/dspmc/helper-server.rs index c5cff3c..013ccc6 100644 --- a/protocol-rpc/src/rpc/dspmc/helper-server.rs +++ b/protocol-rpc/src/rpc/dspmc/helper-server.rs @@ -7,21 +7,20 @@ extern crate clap; extern crate ctrlc; extern crate tonic; -use clap::{App, Arg, ArgGroup}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use log::info; -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread, time, -}; mod rpc_server_helper; -use rpc::{ - connect::create_server::create_server, - proto::gen_dspmc_helper::dspmc_helper_server, -}; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_helper::dspmc_helper_server; #[tokio::main] async fn main() -> Result<(), Box> { @@ -50,8 +49,10 @@ async fn main() -> Result<(), Box> { .long("output-shares-path") .takes_value(true) .required(true) - .help("path to write shares of features.\n - Feature will be written as {path}_partner_features.csv"), + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), Arg::with_name("no-tls") .long("no-tls") .takes_value(false) @@ -127,8 +128,7 @@ async fn main() -> Result<(), Box> { error!("Output shares path not provided"); } - let service = - rpc_server_helper::DspmcHelperService::new(output_keys_path, output_shares_path); + let service = rpc_server_helper::DspmcHelperService::new(output_keys_path, output_shares_path); let ks = service.killswitch.clone(); let recv_thread = thread::spawn(move || { @@ -146,9 +146,7 @@ async fn main() -> Result<(), Box> { let addr = host.unwrap().parse()?; server - .add_service(dspmc_helper_server::DspmcHelperServer::new( - service, - )) + .add_service(dspmc_helper_server::DspmcHelperServer::new(service)) .serve_with_shutdown(addr, async { rx.await.ok(); }) diff --git a/protocol-rpc/src/rpc/dspmc/partner-server.rs b/protocol-rpc/src/rpc/dspmc/partner-server.rs index f8b1770..8af486f 100644 --- a/protocol-rpc/src/rpc/dspmc/partner-server.rs +++ b/protocol-rpc/src/rpc/dspmc/partner-server.rs @@ -7,26 +7,24 @@ extern crate clap; extern crate ctrlc; extern crate tonic; -use clap::{App, Arg, ArgGroup}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; use log::info; -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread, time, -}; -mod rpc_server_partner; mod rpc_client_company; +mod rpc_server_partner; -use rpc::{ - connect::{ create_server::create_server, create_client::create_client, }, - proto::{ - gen_dspmc_partner::dspmc_partner_server, - RpcClient, - }, -}; +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_partner::dspmc_partner_server; +use rpc::proto::RpcClient; #[tokio::main] async fn main() -> Result<(), Box> { @@ -97,15 +95,15 @@ async fn main() -> Result<(), Box> { .takes_value(true) .help("Override TLS domain for SSL cert (if host is IP)"), ]) - .groups(&[ - ArgGroup::with_name("tls") - .args(&["no-tls", "tls-dir", "tls-key"]) - .required(true), - ]) + .groups(&[ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true)]) .get_matches(); let input_keys_path = matches.value_of("input-keys").unwrap_or("input_keys.csv"); - let input_features_path = matches.value_of("input-features").unwrap_or("input_features.csv"); + let input_features_path = matches + .value_of("input-features") + .unwrap_or("input_features.csv"); let input_with_headers = matches.is_present("input-with-headers"); let no_tls = matches.is_present("no-tls"); @@ -145,8 +143,12 @@ async fn main() -> Result<(), Box> { info!("Input path for keys: {}", input_keys_path); info!("Input path for features: {}", input_features_path); - let service = - rpc_server_partner::DspmcPartnerService::new(input_keys_path, input_features_path, input_with_headers, company_client_context); + let service = rpc_server_partner::DspmcPartnerService::new( + input_keys_path, + input_features_path, + input_with_headers, + company_client_context, + ); let ks = service.killswitch.clone(); let recv_thread = thread::spawn(move || { @@ -164,7 +166,7 @@ async fn main() -> Result<(), Box> { let addr = host.unwrap().parse()?; server - .add_service(dspmc_partner_server::DspmcPartnerServer::new(service,)) + .add_service(dspmc_partner_server::DspmcPartnerServer::new(service)) .serve_with_shutdown(addr, async { rx.await.ok(); }) diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs index dbb3ef4..7b08d17 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs @@ -5,16 +5,17 @@ extern crate common; extern crate crypto; extern crate protocol; -use tonic::{transport::Channel, Request, Response, Status}; - use common::timer; use crypto::prelude::TPayload; -use rpc::proto::{ - gen_dspmc_company::{ - dspmc_company_client::DspmcCompanyClient, Commitment, ServiceResponse, - }, - streaming::{read_from_stream, send_data}, -}; +use rpc::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use rpc::proto::gen_dspmc_company::Commitment; +use rpc::proto::gen_dspmc_company::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; pub async fn send( data: TPayload, @@ -52,6 +53,9 @@ pub async fn recv( } pub async fn calculate_id_map(rpc: &mut DspmcCompanyClient) -> Result<(), Status> { - let _r = rpc.calculate_id_map(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs index 9601fcb..84a3c98 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs @@ -5,16 +5,17 @@ extern crate common; extern crate crypto; extern crate protocol; -use tonic::{transport::Channel, Request, Response, Status}; - use common::timer; use crypto::prelude::TPayload; -use rpc::proto::{ - gen_dspmc_helper::{ - dspmc_helper_client::DspmcHelperClient, Commitment, ServiceResponse, - }, - streaming::{send_data, read_from_stream}, -}; +use rpc::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use rpc::proto::gen_dspmc_helper::Commitment; +use rpc::proto::gen_dspmc_helper::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; pub async fn send( data: TPayload, @@ -52,7 +53,10 @@ pub async fn recv( } pub async fn calculate_id_map(rpc: &mut DspmcHelperClient) -> Result<(), Status> { - let _r = rpc.calculate_id_map(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } @@ -62,6 +66,9 @@ pub async fn reveal(rpc: &mut DspmcHelperClient) -> Result<(), Status> } pub async fn stop_service(rpc: &mut DspmcHelperClient) -> Result<(), Status> { - let _r = rpc.stop_service(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs index c20ddd9..aabe369 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs @@ -5,15 +5,15 @@ extern crate common; extern crate crypto; extern crate protocol; -use tonic::{transport::Channel, Request, Response, Status}; - use crypto::prelude::TPayload; -use rpc::proto::{ - gen_dspmc_partner::{ - dspmc_partner_client::DspmcPartnerClient, Commitment, ServiceResponse, - }, - streaming::send_data, -}; +use rpc::proto::gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; +use rpc::proto::gen_dspmc_partner::Commitment; +use rpc::proto::gen_dspmc_partner::ServiceResponse; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; pub async fn send( data: TPayload, @@ -28,6 +28,9 @@ pub async fn send( } pub async fn stop_service(rpc: &mut DspmcPartnerClient) -> Result<(), Status> { - let _r = rpc.stop_service(Request::new(Commitment {})).await?.into_inner(); + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); Ok(()) } diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs index 837a1c2..80df38d 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs @@ -8,35 +8,42 @@ extern crate protocol; extern crate tokio; extern crate tonic; -use std::{ - borrow::BorrowMut, - convert::TryInto, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; -use tonic::{Code, Request, Response, Status, Streaming, transport::Channel}; +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; use common::timer; -use protocol::{ - dspmc::{company::CompanyDspmc, traits::CompanyDspmcProtocol}, - shared::TFeatures, -}; -use rpc::proto::{ - common::Payload, - gen_dspmc_company::{ - dspmc_company_server::DspmcCompany, service_response::*, - Commitment, CommitmentAck, Init, InitAck, SendData, - SendDataAck, ServiceResponse, UPartnerAck, - RecvShares, RecvSharesAck, HelperPublicKeyAck - }, - gen_dspmc_helper::{ - dspmc_helper_client::DspmcHelperClient, - ServiceResponse as HelperServiceResponse, - }, - streaming::{read_from_stream, write_to_stream, send_data, TPayloadStream}, -}; +use protocol::dspmc::company::CompanyDspmc; +use protocol::dspmc::traits::CompanyDspmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_company::dspmc_company_server::DspmcCompany; +use rpc::proto::gen_dspmc_company::service_response::*; +use rpc::proto::gen_dspmc_company::Commitment; +use rpc::proto::gen_dspmc_company::CommitmentAck; +use rpc::proto::gen_dspmc_company::HelperPublicKeyAck; +use rpc::proto::gen_dspmc_company::Init; +use rpc::proto::gen_dspmc_company::InitAck; +use rpc::proto::gen_dspmc_company::RecvShares; +use rpc::proto::gen_dspmc_company::RecvSharesAck; +use rpc::proto::gen_dspmc_company::SendData; +use rpc::proto::gen_dspmc_company::SendDataAck; +use rpc::proto::gen_dspmc_company::ServiceResponse; +use rpc::proto::gen_dspmc_company::UPartnerAck; +use rpc::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use rpc::proto::gen_dspmc_helper::ServiceResponse as HelperServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::transport::Channel; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; pub struct DspmcCompanyService { protocol: CompanyDspmc, @@ -86,7 +93,10 @@ impl DspmcCompany for DspmcCompanyService { })) } - async fn send_ct3_p_cd_v_cd_to_helper(&self, _: Request) -> Result, Status> { + async fn send_ct3_p_cd_v_cd_to_helper( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("send_ct3_p_cd_v_cd_to_helper") @@ -97,14 +107,19 @@ impl DspmcCompany for DspmcCompanyService { // Send ct3 from all partners to helper. - company acts as a client to helper. let partners_ct3 = self.protocol.get_all_ct3_p_cd_v_cd().unwrap(); let mut helper_client_contxt = self.helper_client_context.clone(); - _ = helper_client_contxt.send_ct3_p_cd_v_cd(send_data(partners_ct3)).await; + _ = helper_client_contxt + .send_ct3_p_cd_v_cd(send_data(partners_ct3)) + .await; Ok(Response::new(ServiceResponse { ack: Some(Ack::SendDataAck(SendDataAck {})), })) } - async fn send_u1_to_helper(&self, _: Request) -> Result, Status> { + async fn send_u1_to_helper( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("send_u1_to_helper") @@ -120,7 +135,10 @@ impl DspmcCompany for DspmcCompanyService { })) } - async fn send_encrypted_keys_to_helper(&self, _: Request) -> Result, Status> { + async fn send_encrypted_keys_to_helper( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("send_encrypted_keys_to_helper") @@ -134,14 +152,19 @@ impl DspmcCompany for DspmcCompanyService { let mut helper_client_contxt = self.helper_client_context.clone(); // X, offset, metadata, ct1, ct2, offset, metadata - _ = helper_client_contxt.send_encrypted_keys(send_data(enc_keys)).await; + _ = helper_client_contxt + .send_encrypted_keys(send_data(enc_keys)) + .await; Ok(Response::new(ServiceResponse { ack: Some(Ack::SendDataAck(SendDataAck {})), })) } - async fn recv_shares_from_helper(&self, _: Request) -> Result, Status> { + async fn recv_shares_from_helper( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("recv_shares_from_helper") @@ -149,16 +172,26 @@ impl DspmcCompany for DspmcCompanyService { let mut helper_client_contxt = self.helper_client_context.clone(); let request = Request::new(HelperServiceResponse { - ack: Some(rpc::proto::gen_dspmc_helper::service_response::Ack::UPartnerAck(rpc::proto::gen_dspmc_helper::UPartnerAck {})), + ack: Some( + rpc::proto::gen_dspmc_helper::service_response::Ack::UPartnerAck( + rpc::proto::gen_dspmc_helper::UPartnerAck {}, + ), + ), }); - let mut strm = helper_client_contxt.recv_xor_shares(request).await?.into_inner(); + let mut strm = helper_client_contxt + .recv_xor_shares(request) + .await? + .into_inner(); let mut data = read_from_stream(&mut strm).await?; let num_features = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_rows = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let g_zi = data.drain(num_features * num_rows..).map(|x| x).collect::>(); + let g_zi = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); let mut features = TFeatures::new(); for i in (0..num_features).rev() { @@ -177,7 +210,8 @@ impl DspmcCompany for DspmcCompanyService { None => self.protocol.print_id_map(), } - let resp = self.protocol + let resp = self + .protocol .save_features_shares(&self.output_shares_path.clone().unwrap()) .map(|_| { Response::new(ServiceResponse { @@ -193,20 +227,24 @@ impl DspmcCompany for DspmcCompanyService { resp } - async fn calculate_id_map(&self, _: Request) -> Result, Status> { + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("calculate_id_map") .build(); self.protocol .write_company_to_id_map() - .map(|_| { - Response::new(CommitmentAck {}) - }) + .map(|_| Response::new(CommitmentAck {})) .map_err(|_| Status::new(Code::Aborted, "cannot init the protocol for partner")) } - async fn recv_company_public_key(&self, _: Request) -> Result, Status> { + async fn recv_company_public_key( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("recv_company_public_key") @@ -255,13 +293,11 @@ impl DspmcCompany for DspmcCompanyService { .build(); let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; - let offset_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // flattened len - let data_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_keys = offset_len - 1; let offset = data @@ -270,16 +306,10 @@ impl DspmcCompany for DspmcCompanyService { .collect::>(); assert_eq!(offset_len, offset.len()); - let ct2_dprime_flat = data - .drain((data.len()-data_len)..) - .collect::>(); - let ct1_dprime_flat = data - .drain((data.len()-data_len)..) - .collect::>(); + let ct2_dprime_flat = data.drain((data.len() - data_len)..).collect::>(); + let ct1_dprime_flat = data.drain((data.len() - data_len)..).collect::>(); - let v_sc_bytes = data - .drain((data.len()-num_keys)..) - .collect::>(); + let v_sc_bytes = data.drain((data.len() - num_keys)..).collect::>(); data.shrink_to_fit(); // p_sc self.protocol @@ -340,7 +370,13 @@ impl DspmcCompany for DspmcCompanyService { assert_eq!(offset_len, offset.len()); self.protocol - .set_encrypted_partner_keys_and_shares(ct1.to_vec(), ct2.to_vec(), offset, ct3.buffer, xor_features) + .set_encrypted_partner_keys_and_shares( + ct1.to_vec(), + ct2.to_vec(), + offset, + ct3.buffer, + xor_features, + ) .map(|_| { Response::new(ServiceResponse { ack: Some(Ack::UPartnerAck(UPartnerAck {})), @@ -367,6 +403,4 @@ impl DspmcCompany for DspmcCompanyService { }) .map_err(|_| Status::internal("error writing")) } - - } diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs index 20c7228..81dfc5d 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs @@ -8,30 +8,33 @@ extern crate protocol; extern crate tokio; extern crate tonic; -use std::{ - borrow::BorrowMut, - convert::TryInto, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; -use tonic::{Code, Request, Response, Status, Streaming}; +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; use common::timer; -use protocol::{ - dspmc::{helper::HelperDspmc, traits::HelperDspmcProtocol}, - shared::TFeatures, -}; -use rpc::proto::{ - common::Payload, - gen_dspmc_helper::{ - dspmc_helper_server::DspmcHelper, service_response::*, - Commitment, CommitmentAck, EHelperAck, ServiceResponse, - UPartnerAck, CompanyPublicKeyAck - }, - streaming::{read_from_stream, write_to_stream, TPayloadStream}, -}; +use protocol::dspmc::helper::HelperDspmc; +use protocol::dspmc::traits::HelperDspmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_helper::dspmc_helper_server::DspmcHelper; +use rpc::proto::gen_dspmc_helper::service_response::*; +use rpc::proto::gen_dspmc_helper::Commitment; +use rpc::proto::gen_dspmc_helper::CommitmentAck; +use rpc::proto::gen_dspmc_helper::CompanyPublicKeyAck; +use rpc::proto::gen_dspmc_helper::EHelperAck; +use rpc::proto::gen_dspmc_helper::ServiceResponse; +use rpc::proto::gen_dspmc_helper::UPartnerAck; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; pub struct DspmcHelperService { protocol: HelperDspmc, @@ -79,7 +82,10 @@ impl DspmcHelper for DspmcHelperService { .map_err(|_| Status::internal("error writing")) } - async fn calculate_id_map(&self, _: Request) -> Result, Status> { + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("calculate_id_map") @@ -89,7 +95,10 @@ impl DspmcHelper for DspmcHelperService { Ok(Response::new(CommitmentAck {})) } - async fn recv_helper_public_key(&self, _: Request) -> Result, Status> { + async fn recv_helper_public_key( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("recv_helper_public_key") @@ -102,7 +111,7 @@ impl DspmcHelper for DspmcHelperService { async fn recv_xor_shares( &self, - _: Request + _: Request, ) -> Result, Status> { let _ = timer::Builder::new() .label("server") @@ -116,7 +125,7 @@ impl DspmcHelper for DspmcHelperService { async fn recv_u2( &self, - _: Request + _: Request, ) -> Result, Status> { let _ = timer::Builder::new() .label("server") @@ -142,12 +151,8 @@ impl DspmcHelper for DspmcHelperService { let data_len = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let v_cd_bytes = data - .drain((data.len()-data_len)..) - .collect::>(); - let p_cd_bytes = data - .drain((data.len()-data_len)..) - .collect::>(); + let v_cd_bytes = data.drain((data.len() - data_len)..).collect::>(); + let p_cd_bytes = data.drain((data.len() - data_len)..).collect::>(); let num_partners = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; @@ -175,9 +180,7 @@ impl DspmcHelper for DspmcHelperService { let data_len = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let v_sd_bytes = data - .drain((data.len()-data_len)..) - .collect::>(); + let v_sd_bytes = data.drain((data.len() - data_len)..).collect::>(); data.shrink_to_fit(); self.protocol @@ -204,7 +207,10 @@ impl DspmcHelper for DspmcHelperService { u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_rows = u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let g_zi = data.drain(num_features * num_rows..).map(|x| x).collect::>(); + let g_zi = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); let mut blinded_features = TFeatures::new(); for i in (0..num_features).rev() { @@ -247,7 +253,7 @@ impl DspmcHelper for DspmcHelperService { .drain(i * num_rows..) .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) .collect::>(); - u1.push(x); + u1.push(x); } self.protocol @@ -272,13 +278,11 @@ impl DspmcHelper for DspmcHelperService { // X, offset, metadata, ct1, ct2, offset, metadata let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; - let ct_offset_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let ct_offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // flattened len - let ct_data_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let ct_data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // let num_keys = ct_offset_len - 1; let ct_offset = data @@ -287,21 +291,15 @@ impl DspmcHelper for DspmcHelperService { .collect::>(); assert_eq!(ct_offset_len, ct_offset.len()); - let ct2_flat = data - .drain((data.len()-ct_data_len)..) - .collect::>(); - let ct1_flat = data - .drain((data.len()-ct_data_len)..) - .collect::>(); + let ct2_flat = data.drain((data.len() - ct_data_len)..).collect::>(); + let ct1_flat = data.drain((data.len() - ct_data_len)..).collect::>(); // H(C)*c - let offset_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // flattened len - let data_len = u64::from_le_bytes( - data.pop().unwrap().buffer.as_slice().try_into().unwrap(), - ) as usize; + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; // let num_keys = offset_len - 1; let offset = data @@ -331,11 +329,10 @@ impl DspmcHelper for DspmcHelperService { None => self.protocol.print_id_map(), } - let resp = self.protocol + let resp = self + .protocol .save_features_shares(&self.output_shares_path.clone().unwrap()) - .map(|_| { - Response::new(CommitmentAck {}) - }) + .map(|_| Response::new(CommitmentAck {})) .map_err(|_| Status::internal("error saving feature shares")); { debug!("Setting up flag for graceful down"); @@ -345,7 +342,10 @@ impl DspmcHelper for DspmcHelperService { resp } - async fn stop_service(&self, _: Request) -> Result, Status> { + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("stop") diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs index ac637d1..aa6c1f6 100644 --- a/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs @@ -8,27 +8,33 @@ extern crate protocol; extern crate tokio; extern crate tonic; -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; -use tonic::{Request, Response, Status, Streaming, transport::Channel}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + use common::timer; -use protocol::dspmc::{ - partner::PartnerDspmc, traits::PartnerDspmcProtocol, -}; -use rpc::proto::{ - common::Payload, - gen_dspmc_partner::{ - dspmc_partner_server::DspmcPartner, service_response::*, - Commitment, CommitmentAck, Init, InitAck, SendData, SendDataAck, - ServiceResponse, CompanyPublicKeyAck, HelperPublicKeyAck - }, - gen_dspmc_company::dspmc_company_client::DspmcCompanyClient, - streaming::{read_from_stream, send_data}, -}; +use protocol::dspmc::partner::PartnerDspmc; +use protocol::dspmc::traits::PartnerDspmcProtocol; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use rpc::proto::gen_dspmc_partner::dspmc_partner_server::DspmcPartner; +use rpc::proto::gen_dspmc_partner::service_response::*; +use rpc::proto::gen_dspmc_partner::Commitment; +use rpc::proto::gen_dspmc_partner::CommitmentAck; +use rpc::proto::gen_dspmc_partner::CompanyPublicKeyAck; +use rpc::proto::gen_dspmc_partner::HelperPublicKeyAck; +use rpc::proto::gen_dspmc_partner::Init; +use rpc::proto::gen_dspmc_partner::InitAck; +use rpc::proto::gen_dspmc_partner::SendData; +use rpc::proto::gen_dspmc_partner::SendDataAck; +use rpc::proto::gen_dspmc_partner::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; pub struct DspmcPartnerService { protocol: PartnerDspmc, @@ -59,20 +65,25 @@ impl DspmcPartnerService { #[tonic::async_trait] impl DspmcPartner for DspmcPartnerService { - async fn initialize(&self, _: Request) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("init") .build(); - self.protocol - .load_data(&self.input_keys_path, &self.input_features_path, self.input_with_headers); + self.protocol.load_data( + &self.input_keys_path, + &self.input_features_path, + self.input_with_headers, + ); Ok(Response::new(ServiceResponse { ack: Some(Ack::InitAck(InitAck {})), })) } - async fn send_data_to_company(&self, _: Request) -> Result, Status> { + async fn send_data_to_company( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("init") @@ -88,7 +99,9 @@ impl DspmcPartner for DspmcPartnerService { ct1_ct2.extend(xor_shares); // ct2 + ct1 + offset + XOR shares + metadata + ct3 - _ = company_client_contxt.send_u_partner(send_data(ct1_ct2)).await; + _ = company_client_contxt + .send_u_partner(send_data(ct1_ct2)) + .await; Ok(Response::new(ServiceResponse { ack: Some(Ack::SendDataAck(SendDataAck {})), @@ -133,7 +146,10 @@ impl DspmcPartner for DspmcPartnerService { .map_err(|_| Status::internal("error writing")) } - async fn stop_service(&self, _: Request) -> Result, Status> { + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { let _ = timer::Builder::new() .label("server") .extra_label("stop") diff --git a/protocol/src/dpmc/company.rs b/protocol/src/dpmc/company.rs index 27b85cc..3a73f96 100644 --- a/protocol/src/dpmc/company.rs +++ b/protocol/src/dpmc/company.rs @@ -3,24 +3,29 @@ extern crate csv; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; -use std::{ - collections::{ HashMap, VecDeque }, - convert::TryInto, - path::Path, - sync::{Arc, RwLock}, -}; -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; -use common::{ - permutations::{gen_permute_pattern, permute, undo_permute}, - timer, -}; -use super::{load_data_keys, serialize_helper, writer_helper, ProtocolError}; -use crate::{dpmc::traits::CompanyDpmcProtocol, shared::TFeatures,}; +use super::load_data_keys; +use super::serialize_helper; +use super::writer_helper; +use super::ProtocolError; +use crate::dpmc::traits::CompanyDpmcProtocol; +use crate::shared::TFeatures; #[derive(Debug)] struct PartnerData { @@ -63,13 +68,12 @@ impl CompanyDpmc { } pub fn get_company_public_key(&self) -> Result { - Ok(self.ec_cipher.to_bytes(&vec![self.keypair_pk])) + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) } pub fn load_data(&self, path: &str, input_with_headers: bool) { load_data_keys(self.plaintext.clone(), path, input_with_headers); } - } impl Default for CompanyDpmc { @@ -96,9 +100,7 @@ impl CompanyDpmcProtocol for CompanyDpmc { // Unflatten let pdata = { - let t = self - .ec_cipher - .to_points_encrypt(&data, &self.private_beta); + let t = self.ec_cipher.to_points_encrypt(&data, &self.private_beta); psum.get(0..num_keys) .unwrap() @@ -110,11 +112,11 @@ impl CompanyDpmcProtocol for CompanyDpmc { t.qps("deserialize_exp", pdata.len()); - partners_queue.push_back(PartnerData{ - enc_alpha_t: enc_alpha_t, - scalar_g: scalar_g, + partners_queue.push_back(PartnerData { + enc_alpha_t, + scalar_g, partner_enc_shares: xor_shares, - e_partner: pdata + e_partner: pdata, }); Ok(()) @@ -122,7 +124,7 @@ impl CompanyDpmcProtocol for CompanyDpmc { _ => { error!("Cannot load e_partner"); Err(ProtocolError::ErrorDeserialization( - "cannot load e_partner".to_string(), + "cannot load e_partner".to_string(), )) } } @@ -148,7 +150,9 @@ impl CompanyDpmcProtocol for CompanyDpmc { let (d_flat, offset, metadata) = serialize_helper(d); // Encrypt - let x = self.ec_cipher.hash_encrypt(d_flat.as_slice(), &self.private_beta); + let x = self + .ec_cipher + .hash_encrypt(d_flat.as_slice(), &self.private_beta); (x, offset, metadata) }; @@ -157,11 +161,14 @@ impl CompanyDpmcProtocol for CompanyDpmc { { let psum = offset .iter() - .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .map(|b| { + u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize + }) .collect::>(); let num_keys = psum.len() - 1; - let mut x = psum.get(0..num_keys) + let mut x = psum + .get(0..num_keys) .unwrap() .iter() .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) @@ -222,13 +229,17 @@ impl CompanyDpmcProtocol for CompanyDpmc { d_flat.extend(offset); // Append encrypted key alpha - d_flat.push(ByteBuffer{ buffer: enc_a_t.to_vec(), } ); + d_flat.push(ByteBuffer { + buffer: enc_a_t.to_vec(), + }); - d_flat.push(ByteBuffer{ buffer: scalar_g.to_vec(), } ); + d_flat.push(ByteBuffer { + buffer: scalar_g.to_vec(), + }); // Append offsets array d_flat.extend(enc_shares.clone()); - d_flat.push(ByteBuffer{ + d_flat.push(ByteBuffer { buffer: (enc_shares.len() as u64).to_le_bytes().to_vec(), }); @@ -246,7 +257,7 @@ impl CompanyDpmcProtocol for CompanyDpmc { fn calculate_features_xor_shares( &self, partner_features: TFeatures, - p_mask_d: TPayload + p_mask_d: TPayload, ) -> Result<(), ProtocolError> { match self.partner_shares.clone().write() { Ok(mut shares) => { @@ -256,18 +267,17 @@ impl CompanyDpmcProtocol for CompanyDpmc { let mask = p_mask .iter() .map(|x| { - let t = self.ec_cipher.to_bytes(&vec![x * &self.keypair_sk]); + let t = self.ec_cipher.to_bytes(&[x * self.keypair_sk]); u64::from_le_bytes((t[0].buffer[0..8]).try_into().unwrap()) }) .collect::>(); for f_idx in 0..n_features { - let s = - partner_features[f_idx] - .iter() - .zip_eq(mask.iter()) - .map(|(x1, x2)| *x1 ^ *x2) - .collect::>(); + let s = partner_features[f_idx] + .iter() + .zip_eq(mask.iter()) + .map(|(x1, x2)| *x1 ^ *x2) + .collect::>(); shares.insert(f_idx, s); } @@ -294,10 +304,7 @@ impl CompanyDpmcProtocol for CompanyDpmc { // Get the first column. let company_keys = { - let tmp = company_ragged - .iter() - .map(|s| s[0]) - .collect::>(); + let tmp = company_ragged.iter().map(|s| s[0]).collect::>(); self.ec_cipher.to_bytes(tmp.as_slice()) }; @@ -314,7 +321,7 @@ impl CompanyDpmcProtocol for CompanyDpmc { _ => { error!("Cannot create id_map"); Err(ProtocolError::ErrorDeserialization( - "cannot create id_map".to_string() + "cannot create id_map".to_string(), )) } } diff --git a/protocol/src/dpmc/helper.rs b/protocol/src/dpmc/helper.rs index 794294f..deb0f8d 100644 --- a/protocol/src/dpmc/helper.rs +++ b/protocol/src/dpmc/helper.rs @@ -1,30 +1,34 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -extern crate csv; extern crate base64; +extern crate csv; -use std::{ - collections::{ HashMap, HashSet }, - convert::TryInto, - path::Path, - sync::{Arc, RwLock}, -}; -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; +use std::collections::HashMap; +use std::collections::HashSet; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; -use common::{ - permutations::{gen_permute_pattern, permute, undo_permute}, - timer, -}; -use rand::{distributions::Uniform, Rng,}; -use rayon::iter::{ParallelDrainRange, ParallelIterator}; -use super::{ - writer_helper_dpmc, ProtocolError -}; -use crate::{dpmc::traits::HelperDpmcProtocol, shared::TFeatures,}; +use rand::distributions::Uniform; +use rand::Rng; +use rayon::iter::ParallelDrainRange; +use rayon::iter::ParallelIterator; + +use super::writer_helper_dpmc; +use super::ProtocolError; +use crate::dpmc::traits::HelperDpmcProtocol; +use crate::shared::TFeatures; #[derive(Debug)] struct PartnerData { @@ -68,8 +72,9 @@ impl HelperDpmc { } } - pub fn set_company_public_key(&self, - company_public_key: TPayload + pub fn set_company_public_key( + &self, + company_public_key: TPayload, ) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&company_public_key); // Check that one key is sent @@ -78,7 +83,7 @@ impl HelperDpmc { match self.company_public_key.clone().write() { Ok(mut company_pk) => { *company_pk = pk[0]; - assert_eq!((*company_pk).is_identity(), false); + assert!(!(*company_pk).is_identity()); Ok(()) } _ => { @@ -91,9 +96,8 @@ impl HelperDpmc { } pub fn get_helper_public_key(&self) -> Result { - Ok(self.ec_cipher.to_bytes(&vec![self.keypair_pk])) + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) } - } impl Default for HelperDpmc { @@ -105,19 +109,22 @@ impl Default for HelperDpmc { fn decrypt_shares(mut enc_t: TPayload, aes_key: String) -> (TFeatures, TPayload) { let mut t = { let fernet = fernet::Fernet::new(&aes_key).unwrap(); - enc_t.par_drain(..).map(|x| { - let ctxt_str = String::from_utf8(x.buffer).unwrap(); - ByteBuffer{ - buffer: fernet.decrypt(&ctxt_str).unwrap().to_vec() - } - }).collect::>() + enc_t + .par_drain(..) + .map(|x| { + let ctxt_str = String::from_utf8(x.buffer).unwrap(); + ByteBuffer { + buffer: fernet.decrypt(&ctxt_str).unwrap().to_vec(), + } + }) + .collect::>() }; let num_features = u64::from_le_bytes(t.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; let num_rows = u64::from_le_bytes(t.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let g_zi = t.drain(num_features * num_rows..).map(|x| x).collect::>(); + let g_zi = t.drain(num_features * num_rows..).collect::>(); let mut features = TFeatures::new(); for i in (0..num_features).rev() { @@ -132,14 +139,13 @@ fn decrypt_shares(mut enc_t: TPayload, aes_key: String) -> (TFeatures, TPayload) } impl HelperDpmcProtocol for HelperDpmc { - fn remove_partner_scalar_from_p_and_set_shares( &self, data: TPayload, psum: Vec, enc_alpha_t: Vec, p_scalar_g: TPayload, - xor_shares: TPayload + xor_shares: TPayload, ) -> Result<(), ProtocolError> { match ( self.partners_data.clone().write(), @@ -150,21 +156,26 @@ impl HelperDpmcProtocol for HelperDpmc { let aes_key = { let aes_key_bytes = { - let x = self.ec_cipher.to_points_encrypt(&p_scalar_g, &self.keypair_sk); + let x = self + .ec_cipher + .to_points_encrypt(&p_scalar_g, &self.keypair_sk); let y = self.ec_cipher.to_bytes(&x); y[0].buffer.clone() }; - base64::encode_config(&aes_key_bytes, base64::URL_SAFE) + base64::encode_config(aes_key_bytes, base64::URL_SAFE) }; let alpha_t = { - let ctxt_str: String = - String::from_utf8(enc_alpha_t.clone()).unwrap(); + let ctxt_str: String = String::from_utf8(enc_alpha_t.clone()).unwrap(); Scalar::from_bits( - fernet::Fernet::new(&aes_key).unwrap() - .decrypt(&ctxt_str).unwrap() - .to_vec()[0..32].try_into().unwrap() + fernet::Fernet::new(&aes_key) + .unwrap() + .decrypt(&ctxt_str) + .unwrap() + .to_vec()[0..32] + .try_into() + .unwrap(), ) }; @@ -174,9 +185,7 @@ impl HelperDpmcProtocol for HelperDpmc { // Unflatten let pdata = { - let t = self - .ec_cipher - .to_points_encrypt(&data, &alpha_t.invert()); + let t = self.ec_cipher.to_points_encrypt(&data, &alpha_t.invert()); psum.get(0..num_keys) .unwrap() @@ -190,13 +199,13 @@ impl HelperDpmcProtocol for HelperDpmc { let (features, g_zi) = decrypt_shares(xor_shares, aes_key); - partners_data.push(PartnerData{ + partners_data.push(PartnerData { h_b_partner: pdata, - features: features, - g_zi: g_zi, + features, + g_zi, }); - set_diffs.push(SetDiff{ + set_diffs.push(SetDiff { s_company: HashSet::::new(), s_partner: HashSet::::new(), }); @@ -206,14 +215,16 @@ impl HelperDpmcProtocol for HelperDpmc { _ => { error!("Cannot load e_company"); Err(ProtocolError::ErrorDeserialization( - "cannot load h_b_partner".to_string(), + "cannot load h_b_partner".to_string(), )) } } } - fn set_encrypted_company(&self, - company: TPayload, company_psum: Vec + fn set_encrypted_company( + &self, + company: TPayload, + company_psum: Vec, ) -> Result<(), ProtocolError> { match (self.h_company_beta.clone().write(),) { (Ok(mut h_company_beta),) => { @@ -222,7 +233,8 @@ impl HelperDpmcProtocol for HelperDpmc { h_company_beta.clear(); let e_company = { let t = self.ec_cipher.to_points(&company); - company_psum.get(0..num_keys) + company_psum + .get(0..num_keys) .unwrap() .iter() .zip_eq(company_psum.get(1..num_keys + 1).unwrap().iter()) @@ -412,10 +424,8 @@ impl HelperDpmcProtocol for HelperDpmc { }; // Create a hashmap for all unique partner keys that are not in S_Partner - let mut unique_partner_ids: - HashMap> = HashMap::new(); + let mut unique_partner_ids: HashMap> = HashMap::new(); for p in 0..set_diffs.len() { - // Get the first column. let partner_keys = { let tmp = { @@ -431,8 +441,10 @@ impl HelperDpmcProtocol for HelperDpmc { // if not in S_Partner if !set_diffs[p].s_partner.contains(&key.to_string()) { // if not already in the id map - if !unique_partner_ids.contains_key( &key.to_string() ) { - unique_partner_ids.insert(key.to_string(), vec![(idx, p)]); + if let std::collections::hash_map::Entry::Vacant(e) = + unique_partner_ids.entry(key.to_string()) + { + e.insert(vec![(idx, p)]); } else { let v = unique_partner_ids.get_mut(&key.to_string()).unwrap(); if v.len() < num_of_matches { @@ -445,21 +457,28 @@ impl HelperDpmcProtocol for HelperDpmc { // Add each item of unique_partner_ids into id_map. id_map.clear(); id_map.extend({ - let x = unique_partner_ids.iter_mut().map(|(key, v)| { - v.resize(num_of_matches, (0, 0)); - v.iter() - .map(|(idx, from_p)| (key.to_string(), *idx, true, *from_p)) - .collect::>() - }).collect::>(); + let x = unique_partner_ids + .iter_mut() + .map(|(key, v)| { + v.resize(num_of_matches, (0, 0)); + v.iter() + .map(|(idx, from_p)| (key.to_string(), *idx, true, *from_p)) + .collect::>() + }) + .collect::>(); x.into_iter().flatten().collect::>() }); // Add all the remaining keys that company has but the partners don't. id_map.extend({ - let x = sc_intersection.clone().iter().map(|key| { - (0..num_of_matches) - .map(|_| (key.to_string(), 0, false, 0)) - .collect::>() - }).collect::>(); + let x = sc_intersection + .clone() + .iter() + .map(|key| { + (0..num_of_matches) + .map(|_| (key.to_string(), 0, false, 0)) + .collect::>() + }) + .collect::>(); x.into_iter().flatten().collect::>() }); @@ -490,54 +509,54 @@ impl HelperDpmcProtocol for HelperDpmc { .unwrap(); let (t_i, mut g_zi) = { - let z_i = (0..id_map.len()) - .map(|_| gen_scalar()) - .collect::>(); - let x = z_i.iter() + let z_i = (0..id_map.len()).map(|_| gen_scalar()).collect::>(); + let x = z_i + .iter() .map(|a| { - let x = - self.ec_cipher.to_bytes(&vec![a * *company_public_key]); + let x = self.ec_cipher.to_bytes(&[a * *company_public_key]); x[0].clone() - }).collect::>(); - let y = z_i.iter() + }) + .collect::>(); + let y = z_i + .iter() .map(|a| a * &RISTRETTO_BASEPOINT_TABLE) .collect::>(); (x, y) }; let mut d_flat = { - - let p_mask_v = partners_data. - iter() - .map(|p_data| self.ec_cipher.to_points(&*p_data.g_zi)) + let p_mask_v = partners_data + .iter() + .map(|p_data| self.ec_cipher.to_points(&p_data.g_zi)) .collect::>(); let mut v_p = Vec::::new(); for f_idx in (0..n_features).rev() { let mask = (0..id_map.len()) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); let t = id_map .iter() .enumerate() .map(|(i, (_, idx, exists, from_partner))| { - let y = - if *exists { - if f_idx == 0 { - g_zi[i] = p_mask_v[*from_partner][*idx]; - } - let partner_features = &partners_data[*from_partner].features; - if f_idx < partner_features.len() { - partner_features[f_idx][*idx] ^ mask[i] - } else { - // In case the data are not padded correctly, - // return secret shares of the first feature. - partner_features[0][*idx] ^ mask[i] - } + let y = if *exists { + if f_idx == 0 { + g_zi[i] = p_mask_v[*from_partner][*idx]; + } + let partner_features = &partners_data[*from_partner].features; + if f_idx < partner_features.len() { + partner_features[f_idx][*idx] ^ mask[i] } else { - let y = u64::from_le_bytes((t_i[i].buffer[0..8]).try_into().unwrap()); - y ^ mask[i] - }; + // In case the data are not padded correctly, + // return secret shares of the first feature. + partner_features[0][*idx] ^ mask[i] + } + } else { + let y = u64::from_le_bytes( + (t_i[i].buffer[0..8]).try_into().unwrap(), + ); + y ^ mask[i] + }; ByteBuffer { buffer: y.to_le_bytes().to_vec(), } @@ -559,7 +578,7 @@ impl HelperDpmcProtocol for HelperDpmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); @@ -586,7 +605,7 @@ impl HelperDpmcProtocol for HelperDpmc { .max() .unwrap(); - let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx+1]; + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; for i in 0..id_map.len() { let (_, idx, flag, _) = id_map[i]; @@ -612,7 +631,7 @@ impl HelperDpmcProtocol for HelperDpmc { .max() .unwrap(); - let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx+1]; + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; for i in 0..id_map.len() { let (_, idx, flag, _) = id_map[i]; diff --git a/protocol/src/dpmc/mod.rs b/protocol/src/dpmc/mod.rs index a665be2..128831e 100644 --- a/protocol/src/dpmc/mod.rs +++ b/protocol/src/dpmc/mod.rs @@ -3,8 +3,14 @@ extern crate csv; -use std::{collections::HashSet, sync::{Arc, RwLock}, error::Error, fmt}; -use common::{files, timer}; +use std::collections::HashSet; +use std::error::Error; +use std::fmt; +use std::sync::Arc; +use std::sync::RwLock; + +use common::files; +use common::timer; use crypto::prelude::*; #[derive(Debug)] @@ -34,8 +40,10 @@ fn load_data_keys(plaintext: Arc>>>, path: &str, input_wi if let Ok(mut data) = plaintext.write() { data.clear(); let mut line_it = lines.drain(..); - // Strip the header for now - if input_with_headers && line_it.next().is_some() {} + // Strip the header + if input_with_headers { + line_it.next(); + } let mut t = HashSet::>::new(); // Filter out zero length strings - these will come from ragged @@ -91,7 +99,11 @@ fn load_data_features(plaintext: Arc>>>, path: &str) { t.qps("text read", n_rows); } -fn writer_helper_dpmc(data: &[Vec], id_map: &[(String, usize, bool, usize)], path: Option) { +fn writer_helper_dpmc( + data: &[Vec], + id_map: &[(String, usize, bool, usize)], + path: Option, +) { let mut device = match path { Some(path) => { let wr = csv::WriterBuilder::new() diff --git a/protocol/src/dpmc/partner.rs b/protocol/src/dpmc/partner.rs index 4006fb7..1b5cde5 100644 --- a/protocol/src/dpmc/partner.rs +++ b/protocol/src/dpmc/partner.rs @@ -3,15 +3,26 @@ extern crate base64; -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; -use common::{timer, permutations::{gen_permute_pattern, permute},}; -use rayon::iter::{ParallelDrainRange, ParallelIterator}; -use std::{convert::TryInto, sync::{Arc, RwLock}}; -use super::{load_data_keys, load_data_features, serialize_helper, ProtocolError}; -use crate::{dpmc::traits::PartnerDpmcProtocol, shared::TFeatures}; +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use rayon::iter::ParallelDrainRange; +use rayon::iter::ParallelIterator; + +use super::load_data_features; +use super::load_data_keys; +use super::serialize_helper; +use super::ProtocolError; +use crate::dpmc::traits::PartnerDpmcProtocol; +use crate::shared::TFeatures; pub struct PartnerDpmc { keypair_sk: Scalar, @@ -67,10 +78,13 @@ impl PartnerDpmc { } pub fn get_partner_public_key(&self) -> Result { - Ok(self.ec_cipher.to_bytes(&vec![self.keypair_pk])) + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) } - pub fn set_company_public_key(&self, company_public_key: TPayload) -> Result<(), ProtocolError> { + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&company_public_key); // Check that one key is sent assert_eq!(pk.len(), 1); @@ -78,7 +92,7 @@ impl PartnerDpmc { match self.company_public_key.clone().write() { Ok(mut company_pk) => { *company_pk = pk[0]; - assert_eq!((*company_pk).is_identity(), false); + assert!(!(*company_pk).is_identity()); Ok(()) } _ => { @@ -96,16 +110,18 @@ impl PartnerDpmc { assert_eq!(pk.len(), 1); match ( self.helper_public_key.clone().write(), - self.aes_key.clone().write() + self.aes_key.clone().write(), ) { (Ok(mut helper_pk), Ok(mut aes_key)) => { *helper_pk = pk[0]; - assert_eq!((*helper_pk).is_identity(), false); + assert!(!(*helper_pk).is_identity()); *aes_key = { - let x = self.ec_cipher.to_bytes(&vec![self.partner_scalar * (*helper_pk)]); + let x = self + .ec_cipher + .to_bytes(&[self.partner_scalar * (*helper_pk)]); let aes_key_bytes = x[0].buffer.clone(); - base64::encode_config(&aes_key_bytes, base64::URL_SAFE) + base64::encode_config(aes_key_bytes, base64::URL_SAFE) }; Ok(()) } @@ -117,7 +133,6 @@ impl PartnerDpmc { } } } - } impl Default for PartnerDpmc { @@ -151,7 +166,8 @@ impl PartnerDpmcProtocol for PartnerDpmc { // Encrypt ( // Blind the keys by encrypting - self.ec_cipher.hash_encrypt_to_bytes(d_flat.as_slice(), &self.keypair_sk), + self.ec_cipher + .hash_encrypt_to_bytes(d_flat.as_slice(), &self.keypair_sk), offset, ) }; @@ -165,12 +181,12 @@ impl PartnerDpmcProtocol for PartnerDpmc { let ctxt = fernet.encrypt(self.keypair_sk.to_bytes().clone().as_slice()); // Append encrypted key alpha d_flat.push(ByteBuffer { - buffer: ctxt.as_bytes().to_vec() + buffer: ctxt.as_bytes().to_vec(), }); - let p_scalar_times_g = self.ec_cipher.to_bytes( - &vec![&self.partner_scalar * &RISTRETTO_BASEPOINT_TABLE] - ); + let p_scalar_times_g = self + .ec_cipher + .to_bytes(&[&self.partner_scalar * &RISTRETTO_BASEPOINT_TABLE]); d_flat.extend(p_scalar_times_g); Ok(d_flat) @@ -202,13 +218,11 @@ impl PartnerDpmcProtocol for PartnerDpmc { permute(permutation.as_slice(), &mut permuted_pdata[i]); } - let z_i = - (0..n_rows) - .map(|x| x) - .collect::>() - .iter() - .map(|_| gen_scalar()) - .collect::>(); + let z_i = (0..n_rows) + .collect::>() + .iter() + .map(|_| gen_scalar()) + .collect::>(); let mut d_flat = { let r_i = { @@ -219,8 +233,7 @@ impl PartnerDpmcProtocol for PartnerDpmc { .collect::>(); self.ec_cipher.to_bytes(&t) }; - y_zi - .iter() + y_zi.iter() .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>() }; @@ -238,7 +251,7 @@ impl PartnerDpmcProtocol for PartnerDpmc { buffer: z.to_le_bytes().to_vec(), } }) - .collect::>(); + .collect::>(); v_p.push(t); } @@ -262,18 +275,21 @@ impl PartnerDpmcProtocol for PartnerDpmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); let e_d_flat = { - let fernet = fernet::Fernet::new(&(*aes_key.clone())).unwrap(); - d_flat.par_drain(..).map(|x| { - let ctxt = fernet.encrypt(x.buffer.as_slice()); - ByteBuffer { - buffer: ctxt.as_bytes().to_vec() - } - }).collect::>() + let fernet = fernet::Fernet::new(&aes_key.clone()).unwrap(); + d_flat + .par_drain(..) + .map(|x| { + let ctxt = fernet.encrypt(x.buffer.as_slice()); + ByteBuffer { + buffer: ctxt.as_bytes().to_vec(), + } + }) + .collect::>() }; t.qps("e_d_flat", e_d_flat.len()); diff --git a/protocol/src/dpmc/traits.rs b/protocol/src/dpmc/traits.rs index 5b3a1cb..5ce690e 100644 --- a/protocol/src/dpmc/traits.rs +++ b/protocol/src/dpmc/traits.rs @@ -3,12 +3,11 @@ extern crate crypto; -use crate::{ - dpmc::ProtocolError, - shared::TFeatures, -}; use crypto::prelude::TPayload; +use crate::dpmc::ProtocolError; +use crate::shared::TFeatures; + pub trait PartnerDpmcProtocol { fn get_encrypted_keys(&self) -> Result; fn get_features_xor_shares(&self) -> Result; @@ -21,12 +20,14 @@ pub trait HelperDpmcProtocol { psum: Vec, enc_alpha_t: Vec, p_scalar_g: TPayload, - xor_shares: TPayload + xor_shares: TPayload, ) -> Result<(), ProtocolError>; fn calculate_set_diff(&self, partner_num: usize) -> Result<(), ProtocolError>; fn calculate_id_map(&self, calculate_id_map: usize); - fn set_encrypted_company(&self, - company: TPayload, company_psum: Vec + fn set_encrypted_company( + &self, + company: TPayload, + company_psum: Vec, ) -> Result<(), ProtocolError>; fn calculate_features_xor_shares(&self) -> Result; fn print_id_map(&self); @@ -45,7 +46,11 @@ pub trait CompanyDpmcProtocol { ) -> Result<(), ProtocolError>; fn get_permuted_keys(&self) -> Result; fn serialize_encrypted_keys_and_features(&self) -> Result; - fn calculate_features_xor_shares(&self, features: TFeatures, data: TPayload) -> Result<(), ProtocolError>; + fn calculate_features_xor_shares( + &self, + features: TFeatures, + data: TPayload, + ) -> Result<(), ProtocolError>; fn write_company_to_id_map(&self) -> Result<(), ProtocolError>; fn print_id_map(&self); fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; diff --git a/protocol/src/dspmc/company.rs b/protocol/src/dspmc/company.rs index cd36849..4ca1b70 100644 --- a/protocol/src/dspmc/company.rs +++ b/protocol/src/dspmc/company.rs @@ -3,35 +3,31 @@ extern crate csv; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; -use std::{ - collections::{ HashMap, VecDeque }, - convert::TryInto, - path::Path, - sync::{Arc, RwLock}, -}; - -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; - -use common::{ - permutations::{gen_permute_pattern, permute, undo_permute}, - timer, -}; - -use rand::{ - distributions::Uniform, - Rng, -}; - -use crate::{ - dspmc::traits::CompanyDspmcProtocol, - shared::TFeatures, -}; - -use super::{load_data_keys, serialize_helper, writer_helper, ProtocolError}; +use rand::distributions::Uniform; +use rand::Rng; + +use super::load_data_keys; +use super::serialize_helper; +use super::writer_helper; +use super::ProtocolError; +use crate::dspmc::traits::CompanyDspmcProtocol; +use crate::shared::TFeatures; #[derive(Debug)] struct PartnerData { @@ -52,8 +48,8 @@ pub struct CompanyDspmc { u1: Arc>>, plaintext: Arc>>>, permutation: Arc>>, - perms: Arc, Vec)>>, // (p_3, p_4) - blinds: Arc, Vec)>>, // (v_cd, v_cs) + perms: Arc, Vec)>>, // (p_3, p_4) + blinds: Arc, Vec)>>, // (v_cd, v_cs) enc_company: Arc>>>, partners_queue: Arc>>, id_map: Arc>>, @@ -66,7 +62,10 @@ impl CompanyDspmc { let x2 = gen_scalar(); CompanyDspmc { keypair_sk: (x1, x2), - keypair_pk: (&x1 * &RISTRETTO_BASEPOINT_TABLE, &x2 * &RISTRETTO_BASEPOINT_TABLE), + keypair_pk: ( + &x1 * &RISTRETTO_BASEPOINT_TABLE, + &x2 * &RISTRETTO_BASEPOINT_TABLE, + ), helper_public_key: Arc::new(RwLock::default()), ec_cipher: ECRistrettoParallel::default(), ct1: Arc::new(RwLock::default()), @@ -85,7 +84,9 @@ impl CompanyDspmc { } pub fn get_company_public_key(&self) -> Result { - Ok(self.ec_cipher.to_bytes(&vec![self.keypair_pk.0, self.keypair_pk.1])) + Ok(self + .ec_cipher + .to_bytes(&vec![self.keypair_pk.0, self.keypair_pk.1])) } pub fn load_data(&self, path: &str, input_with_headers: bool) { @@ -110,10 +111,10 @@ impl CompanyDspmc { perms.1.extend(gen_permute_pattern(data_len)); blinds.0 = (0..data_len) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); blinds.1 = (0..data_len) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); } _ => {} @@ -137,7 +138,6 @@ impl CompanyDspmc { } } } - } impl Default for CompanyDspmc { @@ -147,7 +147,6 @@ impl Default for CompanyDspmc { } impl CompanyDspmcProtocol for CompanyDspmc { - fn set_encrypted_partner_keys_and_shares( &self, ct1: TPayload, @@ -162,8 +161,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { self.ct2.clone().write(), self.v1.clone().write(), ) { - (Ok(mut partners_queue), Ok(mut all_ct1), - Ok(mut all_ct2), Ok(mut all_v1)) => { + (Ok(mut partners_queue), Ok(mut all_ct1), Ok(mut all_ct2), Ok(mut all_v1)) => { let t = timer::Timer::new_silent("load_ct2"); // This is an array of exclusive-inclusive prefix sum - hence // number of keys is one less than length @@ -209,17 +207,17 @@ impl CompanyDspmcProtocol for CompanyDspmc { } } - partners_queue.push_back(PartnerData{ + partners_queue.push_back(PartnerData { scalar_g: ct3, - n_rows: n_rows, - n_features: n_features, + n_rows, + n_features, }); Ok(()) } _ => { error!("Cannot load ct2"); Err(ProtocolError::ErrorDeserialization( - "cannot load ct2".to_string(), + "cannot load ct2".to_string(), )) } } @@ -228,7 +226,10 @@ impl CompanyDspmcProtocol for CompanyDspmc { // Get dataset C with company keys and encrypt them to H(C)^c // With Elliptic curves: H(C)*c fn get_company_keys(&self) -> Result { - match (self.plaintext.clone().read(), self.enc_company.clone().write(),) { + match ( + self.plaintext.clone().read(), + self.enc_company.clone().write(), + ) { (Ok(pdata), Ok(mut enc_company)) => { let t = timer::Timer::new_silent("x_company"); @@ -237,7 +238,9 @@ impl CompanyDspmcProtocol for CompanyDspmc { let (d_flat, offset, metadata) = serialize_helper(pdata.to_vec()); // Hash Encrypt - H(C)^c - let enc = self.ec_cipher.hash_encrypt(d_flat.as_slice(), &self.keypair_sk.0); + let enc = self + .ec_cipher + .hash_encrypt(d_flat.as_slice(), &self.keypair_sk.0); (enc, offset, metadata) }; @@ -246,11 +249,14 @@ impl CompanyDspmcProtocol for CompanyDspmc { { let psum = offset .iter() - .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .map(|b| { + u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize + }) .collect::>(); let num_keys = psum.len() - 1; - let mut x = psum.get(0..num_keys) + let mut x = psum + .get(0..num_keys) .unwrap() .iter() .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) @@ -289,8 +295,8 @@ impl CompanyDspmcProtocol for CompanyDspmc { // Get dataset ct1 and ct2' fn get_ct1_ct2(&self) -> Result { - match (self.ct1.clone().read(), self.ct2.clone().read(),) { - (Ok(ct1), Ok(ct2),) => { + match (self.ct1.clone().read(), self.ct2.clone().read()) { + (Ok(ct1), Ok(ct2)) => { let t = timer::Timer::new_silent("x_company"); // Re-randomize ct1'' and ct2'' and flatten @@ -309,7 +315,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { let d_flat_c = d_flat .iter() - .map(|x| *x * (&self.keypair_sk.0)) + .map(|x| *x * (self.keypair_sk.0)) .collect::>(); (self.ec_cipher.to_bytes(d_flat_c.as_slice()), offset) @@ -350,35 +356,39 @@ impl CompanyDspmcProtocol for CompanyDspmc { let n_rows = partner_data.n_rows; let n_features = partner_data.n_features; - res.push(ByteBuffer{ buffer: ct3 }); + res.push(ByteBuffer { buffer: ct3 }); let metadata = vec![ ByteBuffer { buffer: (n_rows as u64).to_le_bytes().to_vec(), }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; res.extend(metadata); } - res.push(ByteBuffer{ + res.push(ByteBuffer { buffer: (num_partners as u64).to_le_bytes().to_vec(), }); - let p_cd_bytes = perms.0.iter() + let p_cd_bytes = perms + .0 + .iter() .map(|e| ByteBuffer { - buffer: (*e as u64).to_le_bytes().to_vec(), + buffer: (*e).to_le_bytes().to_vec(), }) .collect::>(); - let v_cd_bytes = blinds.0.iter() + let v_cd_bytes = blinds + .0 + .iter() .map(|e| ByteBuffer { - buffer: (*e as u64).to_le_bytes().to_vec(), + buffer: (*e).to_le_bytes().to_vec(), }) .collect::>(); let data_len = p_cd_bytes.len(); res.extend(p_cd_bytes); res.extend(v_cd_bytes); - res.push(ByteBuffer{ + res.push(ByteBuffer { buffer: (data_len as u64).to_le_bytes().to_vec(), }); @@ -395,19 +405,25 @@ impl CompanyDspmcProtocol for CompanyDspmc { fn get_p_cs_v_cs(&self) -> Result { match ( - self.perms.clone().read(), self.blinds.clone().read(), - self.ct1.clone().write(), self.ct2.clone().write(), + self.perms.clone().read(), + self.blinds.clone().read(), + self.ct1.clone().write(), + self.ct2.clone().write(), self.helper_public_key.clone().read(), ) { (Ok(perms), Ok(blinds), Ok(mut ct1), Ok(mut ct2), Ok(helper_pk)) => { let mut res = vec![]; - let p_cs_bytes = perms.1.iter() + let p_cs_bytes = perms + .1 + .iter() .map(|e| ByteBuffer { buffer: (*e as u64).to_le_bytes().to_vec(), }) .collect::>(); - let v_cs_bytes = blinds.1.iter() + let v_cs_bytes = blinds + .1 + .iter() .map(|e| ByteBuffer { buffer: (*e as u64).to_le_bytes().to_vec(), }) @@ -419,7 +435,6 @@ impl CompanyDspmcProtocol for CompanyDspmc { // Re-randomize ct1 and ct2 to ct1' and ct2' let (ct1_prime_flat, ct2_prime_flat, ct_offset) = { let r_i = (0..data_len) - .map(|x| x) .collect::>() .iter() .map(|_| gen_scalar()) @@ -428,14 +443,11 @@ impl CompanyDspmcProtocol for CompanyDspmc { // with EC: company_pk * r let pkc_r = r_i .iter() - .map(|x| *x * (&self.keypair_pk.0)) + .map(|x| *x * (self.keypair_pk.0)) .collect::>(); // helper_pk^r // with EC: helper_pk * r - let pkd_r = r_i - .iter() - .map(|x| *x * (*helper_pk)) - .collect::>(); + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); permute(perms.0.as_slice(), &mut ct1); // p_cd permute(perms.0.as_slice(), &mut ct2); // p_cd @@ -447,18 +459,14 @@ impl CompanyDspmcProtocol for CompanyDspmc { let ct1_prime = ct1 .iter() .zip_eq(pkc_r.iter()) - .map(|(s, t)| { - (*s).iter().map(|si| *si + *t).collect::>() - }) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) .collect::>(); // ct2' = p_4(p_3(ct2)) * helper_pk^r // with EC: ct2' = p_4(p_3(ct2)) + helper_pk*r let ct2_prime = ct2 .iter() .zip_eq(pkd_r.iter()) - .map(|(s, t)| { - (*s).iter().map(|si| *si + *t).collect::>() - }) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) .collect::>(); let (ct1_prime_flat, _ct1_offset) = { @@ -503,21 +511,22 @@ impl CompanyDspmcProtocol for CompanyDspmc { psum: Vec, ) -> Result<(), ProtocolError> { match ( - self.ct1.clone().write(), self.ct2.clone().write(), - self.perms.clone().read(), self.blinds.clone().read(), - self.v1.clone().write(), self.u1.clone().write() + self.ct1.clone().write(), + self.ct2.clone().write(), + self.perms.clone().read(), + self.blinds.clone().read(), + self.v1.clone().write(), + self.u1.clone().write(), ) { - ( - Ok(mut ct1), Ok(mut ct2), Ok(perms), - Ok(blinds), Ok(mut v1), Ok(mut u1) - ) => { + (Ok(mut ct1), Ok(mut ct2), Ok(perms), Ok(blinds), Ok(mut v1), Ok(mut u1)) => { let t = timer::Timer::new_silent("set set_p_sc_v_sc_ct1ct2dprime"); let num_keys = v_sc_bytes.len(); // Remove the previous data and replace them with the (doubly) re-randomized ct1.clear(); ct2.clear(); // Unflatten and convert to points - *ct1 = { // ct1'' (doubly re-randomized ct1) + *ct1 = { + // ct1'' (doubly re-randomized ct1) let t = self.ec_cipher.to_points(&ct1_dprime_flat); psum.get(0..num_keys) @@ -528,7 +537,8 @@ impl CompanyDspmcProtocol for CompanyDspmc { .collect::>>() }; // Unflatten and convert to points - *ct2 = { // ct2'' (doubly re-randomized ct2) + *ct2 = { + // ct2'' (doubly re-randomized ct2) let t = self.ec_cipher.to_points(&ct2_dprime_flat); psum.get(0..num_keys) @@ -541,46 +551,43 @@ impl CompanyDspmcProtocol for CompanyDspmc { let p_sc = p_sc_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) .collect::>(); let v_sc = v_sc_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>(); let n_features = v1.len(); // Compute u1 = p_sc( p_cs( p_cd(v_1) xor v_cd) xor v_cs) xor v_sc (*u1).clear(); for f_idx in (0..n_features).rev() { - permute(perms.0.as_slice(), &mut v1[f_idx]); // p_cd + permute(perms.0.as_slice(), &mut v1[f_idx]); // p_cd let mut u2 = v1[f_idx] .iter() - .zip_eq(blinds.0.iter()) // v_cd + .zip_eq(blinds.0.iter()) // v_cd .map(|(s, v_cd)| *s ^ *v_cd) .collect::>(); - permute(perms.1.as_slice(), &mut u2); // p_cs + permute(perms.1.as_slice(), &mut u2); // p_cs let mut t1 = u2 .iter() - .zip_eq(blinds.1.iter()) // v_cs + .zip_eq(blinds.1.iter()) // v_cs .map(|(s, v_cs)| *s ^ *v_cs) .collect::>(); - permute(p_sc.as_slice(), &mut t1); // p_sc - (*u1).push(t1 - .iter() - .zip_eq(v_sc.iter()) - .map(|(s, v_sc)| { // v_sc - let y = *s ^ *v_sc; - ByteBuffer { - buffer: y.to_le_bytes().to_vec(), - } - }) - .collect::>() + permute(p_sc.as_slice(), &mut t1); // p_sc + (*u1).push( + t1.iter() + .zip_eq(v_sc.iter()) + .map(|(s, v_sc)| { + // v_sc + let y = *s ^ *v_sc; + ByteBuffer { + buffer: y.to_le_bytes().to_vec(), + } + }) + .collect::>(), ); } @@ -590,7 +597,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { _ => { error!("Cannot flatten ct1'' and ct2''"); Err(ProtocolError::ErrorDeserialization( - "cannot flatten ct1'' and ct2''".to_string(), + "cannot flatten ct1'' and ct2''".to_string(), )) } } @@ -612,7 +619,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); @@ -641,7 +648,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { let r = g_zi_pt .iter() .map(|x| { - let t = self.ec_cipher.to_bytes(&vec![x * &self.keypair_sk.0]); + let t = self.ec_cipher.to_bytes(&[x * self.keypair_sk.0]); u64::from_le_bytes((t[0].buffer[0..8]).try_into().unwrap()) }) .collect::>(); @@ -678,10 +685,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { // Get the first column. let company_keys = { - let tmp = company_ragged - .iter() - .map(|s| s[0]) - .collect::>(); + let tmp = company_ragged.iter().map(|s| s[0]).collect::>(); self.ec_cipher.to_bytes(tmp.as_slice()) }; @@ -698,7 +702,7 @@ impl CompanyDspmcProtocol for CompanyDspmc { _ => { error!("Cannot create id_map"); Err(ProtocolError::ErrorDeserialization( - "cannot create id_map".to_string() + "cannot create id_map".to_string(), )) } } diff --git a/protocol/src/dspmc/helper.rs b/protocol/src/dspmc/helper.rs index f982546..f58d9fa 100644 --- a/protocol/src/dspmc/helper.rs +++ b/protocol/src/dspmc/helper.rs @@ -3,36 +3,28 @@ extern crate csv; +use std::collections::HashMap; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; -use std::{ - collections::{ HashMap }, - convert::TryInto, - path::Path, - sync::{Arc, RwLock}, -}; - -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; - -use common::{ - permutations::{gen_permute_pattern, permute}, - timer, -}; - -use rand::{ - distributions::Uniform, - prelude::*, - Rng, -}; - -use crate::{ - dspmc::traits::HelperDspmcProtocol, - shared::TFeatures, -}; - -use super::{writer_helper, ProtocolError}; +use rand::distributions::Uniform; +use rand::prelude::*; +use rand::Rng; + +use super::writer_helper; +use super::ProtocolError; +use crate::dspmc::traits::HelperDspmcProtocol; +use crate::shared::TFeatures; #[derive(Debug)] pub struct HelperDspmc { @@ -40,15 +32,15 @@ pub struct HelperDspmc { keypair_pk: TPoint, ec_cipher: ECRistrettoParallel, company_public_key: Arc>, - xor_shares_v2: Arc>, // v2 = v xor v1 -- The shuffler has v1 - enc_company: Arc>>>, // H(C)^c - enc_partners: Arc>>>, // H(P)^c - features: Arc>, // v''' from shuffler + xor_shares_v2: Arc>, // v2 = v xor v1 -- The shuffler has v1 + enc_company: Arc>>>, // H(C)^c + enc_partners: Arc>>>, // H(P)^c + features: Arc>, // v''' from shuffler p_cd: Arc>>, v_cd: Arc>>, p_sd: Arc>>, v_sd: Arc>>, - shuffler_gz: Arc>>, // h = g^z from shuffler + shuffler_gz: Arc>>, // h = g^z from shuffler s_company: Arc>>, s_partner: Arc>>, id_map: Arc>>, @@ -80,11 +72,12 @@ impl HelperDspmc { } pub fn get_helper_public_key(&self) -> Result { - Ok(self.ec_cipher.to_bytes(&vec![self.keypair_pk])) + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) } - pub fn set_company_public_key(&self, - company_public_key: TPayload + pub fn set_company_public_key( + &self, + company_public_key: TPayload, ) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&company_public_key); // Check that two keys are sent @@ -92,10 +85,10 @@ impl HelperDspmc { match self.company_public_key.clone().write() { Ok(mut company_pk) => { - (*company_pk).0 = pk[0]; - (*company_pk).1 = pk[1]; - assert_eq!(((*company_pk).0).is_identity(), false); - assert_eq!(((*company_pk).1).is_identity(), false); + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); Ok(()) } _ => { @@ -106,7 +99,6 @@ impl HelperDspmc { } } } - } impl Default for HelperDspmc { @@ -115,9 +107,7 @@ impl Default for HelperDspmc { } } - impl HelperDspmcProtocol for HelperDspmc { - fn set_ct3p_cd_v_cd( &self, mut data: TPayload, @@ -134,21 +124,22 @@ impl HelperDspmcProtocol for HelperDspmc { let t = timer::Timer::new_silent("set v''"); for _ in 0..num_partners { // Data in form [(ct3, metadata), (ct3, metadata), ... ] - let n_features = - u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; - let n_rows = - u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let n_features = u64::from_le_bytes( + data.pop().unwrap().buffer.as_slice().try_into().unwrap(), + ) as usize; + let n_rows = u64::from_le_bytes( + data.pop().unwrap().buffer.as_slice().try_into().unwrap(), + ) as usize; - let ct3 = data - .drain((data.len() - 1)..) - .collect::>(); + let ct3 = data.drain((data.len() - 1)..).collect::>(); // PRG seed = scalar * PK_helper let seed = { let x = self.ec_cipher.to_points_encrypt(&ct3, &self.keypair_sk); &self.ec_cipher.to_bytes(&x)[0].buffer }; - let seed_array: [u8; 32] = seed.as_slice().try_into().expect("incorrect length"); + let seed_array: [u8; 32] = + seed.as_slice().try_into().expect("incorrect length"); let mut rng = StdRng::from_seed(seed_array); // Merge features from all partners together. Example: @@ -165,7 +156,6 @@ impl HelperDspmcProtocol for HelperDspmc { // Merged: [[10, 20, 30, 40], [11, 21, 31, 41], [12, 22, 32, 42]] for f_idx in 0..n_features { let t = (0..n_rows) - .map(|x| x) .collect::>() .iter() .map(|_| rng.gen::()) @@ -180,16 +170,12 @@ impl HelperDspmcProtocol for HelperDspmc { *v_cd = v_cd_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>(); *p_cd = p_cd_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) .collect::>(); t.qps("deserialize_exp", xor_shares_v2.len()); @@ -199,7 +185,7 @@ impl HelperDspmcProtocol for HelperDspmc { _ => { error!("Cannot load xor_shares_v2"); Err(ProtocolError::ErrorDeserialization( - "cannot load xor_shares_v2".to_string(), + "cannot load xor_shares_v2".to_string(), )) } } @@ -214,7 +200,7 @@ impl HelperDspmcProtocol for HelperDspmc { self.features.clone().write(), self.shuffler_gz.clone().write(), ) { - (Ok(mut features), Ok(mut shuffler_gz),) => { + (Ok(mut features), Ok(mut shuffler_gz)) => { let t = timer::Timer::new_silent("set_encrypted_vprime"); features.clear(); @@ -230,7 +216,7 @@ impl HelperDspmcProtocol for HelperDspmc { _ => { error!("Cannot load encrypted_vprime"); Err(ProtocolError::ErrorDeserialization( - "cannot load encrypted_vprime".to_string(), + "cannot load encrypted_vprime".to_string(), )) } } @@ -241,25 +227,18 @@ impl HelperDspmcProtocol for HelperDspmc { v_sd_bytes: TPayload, p_sd_bytes: TPayload, ) -> Result<(), ProtocolError> { - match ( - self.p_sd.clone().write(), - self.v_sd.clone().write(), - ) { + match (self.p_sd.clone().write(), self.v_sd.clone().write()) { (Ok(mut p_sd), Ok(mut v_sd)) => { let t = timer::Timer::new_silent("set set_p_sd_v_sd"); *v_sd = v_sd_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>(); *p_sd = p_sd_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) .collect::>(); t.qps("deserialize_exp", (*p_sd).len()); @@ -269,16 +248,13 @@ impl HelperDspmcProtocol for HelperDspmc { _ => { error!("Cannot load set_p_sd_v_sd"); Err(ProtocolError::ErrorDeserialization( - "cannot load set_p_sd_v_sd".to_string(), + "cannot load set_p_sd_v_sd".to_string(), )) } } } - fn set_u1( - &self, - mut u1: TFeatures, - ) -> Result<(), ProtocolError> { + fn set_u1(&self, mut u1: TFeatures) -> Result<(), ProtocolError> { match ( self.p_sd.clone().read(), self.v_sd.clone().read(), @@ -291,7 +267,7 @@ impl HelperDspmcProtocol for HelperDspmc { xor_shares_v2.clear(); for f_idx in 0..n_features { - permute(p_sd.as_slice(), &mut u1[f_idx]); // p_sc + permute(p_sd.as_slice(), &mut u1[f_idx]); // p_sc let t = u1[f_idx] .iter() .zip_eq(v_sd.iter()) @@ -306,7 +282,7 @@ impl HelperDspmcProtocol for HelperDspmc { _ => { error!("Cannot load set_u1"); Err(ProtocolError::ErrorDeserialization( - "cannot load set_u1".to_string(), + "cannot load set_u1".to_string(), )) } } @@ -328,7 +304,7 @@ impl HelperDspmcProtocol for HelperDspmc { self.enc_company.clone().write(), self.enc_partners.clone().write(), ) { - (Ok(mut enc_company), Ok(mut enc_partners),) => { + (Ok(mut enc_company), Ok(mut enc_partners)) => { let t = timer::Timer::new_silent("set set_encrypted_keys"); // Unflatten and convert to points @@ -348,7 +324,8 @@ impl HelperDspmcProtocol for HelperDspmc { .map(|(s2, s1)| *s2 - *s1) .collect::>(); - ct_psum.get(0..num_ct_keys) + ct_psum + .get(0..num_ct_keys) .unwrap() .iter() .zip_eq(ct_psum.get(1..num_ct_keys + 1).unwrap().iter()) @@ -375,7 +352,7 @@ impl HelperDspmcProtocol for HelperDspmc { _ => { error!("Cannot load set_encrypted_keys"); Err(ProtocolError::ErrorDeserialization( - "cannot load set_encrypted_keys".to_string(), + "cannot load set_encrypted_keys".to_string(), )) } } @@ -402,19 +379,13 @@ impl HelperDspmcProtocol for HelperDspmc { (Ok(enc_partners), Ok(enc_company), Ok(s_partner), Ok(s_company), Ok(mut id_map)) => { // Get the first column. let partner_keys = { - let tmp = enc_partners - .iter() - .map(|s| s[0]) - .collect::>(); + let tmp = enc_partners.iter().map(|s| s[0]).collect::>(); self.ec_cipher.to_bytes(tmp.as_slice()) }; // Get the first column. let company_keys = { - let tmp = enc_company - .iter() - .map(|s| s[0]) - .collect::>(); + let tmp = enc_company.iter().map(|s| s[0]).collect::>(); self.ec_cipher.to_bytes(tmp.as_slice()) }; @@ -437,15 +408,16 @@ impl HelperDspmcProtocol for HelperDspmc { if id_hashmap.contains_key(&key.to_string()) { continue; } - if company_keys_map.contains_key(&key.to_string()) || - !s_partner_map.contains_key(&key.to_string()) { - id_hashmap.insert(key.to_string(), (idx as usize, true)); + if company_keys_map.contains_key(&key.to_string()) + || !s_partner_map.contains_key(&key.to_string()) + { + id_hashmap.insert(key.to_string(), (idx, true)); } } // Add all the remaining keys that company has but the partner doesn't. for (idx, key) in s_company.iter().enumerate() { - id_hashmap.insert(key.to_string(), (idx as usize, false)); + id_hashmap.insert(key.to_string(), (idx, false)); } id_map.clear(); @@ -468,7 +440,7 @@ impl HelperDspmcProtocol for HelperDspmc { self.s_company.clone().write(), self.s_partner.clone().write(), ) { - (Ok(e_company), Ok(mut e_partner), Ok(mut s_company), Ok(mut s_partner),) => { + (Ok(e_company), Ok(mut e_partner), Ok(mut s_company), Ok(mut s_partner)) => { // let t = timer::Timer::new_silent("helper calculate_set_diff"); let s_c = e_company.iter().map(|e| e[0]).collect::>(); @@ -526,7 +498,7 @@ impl HelperDspmcProtocol for HelperDspmc { // if the match occurred not in the first column, // make sure the spine keys will be the same. if idx > 0 { - e[0] = e_company[m_idx][0].clone(); + e[0] = e_company[m_idx][0]; } e_c_valid[m_idx] = false; e_p_match_idx.push(i); @@ -568,7 +540,7 @@ impl HelperDspmcProtocol for HelperDspmc { s_company.clear(); if !t.is_empty() { - s_company.extend(self.ec_cipher.to_bytes(t.as_slice()),); + s_company.extend(self.ec_cipher.to_bytes(t.as_slice())); } // t.qps("s_company", s_company.len()); @@ -619,7 +591,7 @@ impl HelperDspmcProtocol for HelperDspmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); @@ -643,14 +615,15 @@ impl HelperDspmcProtocol for HelperDspmc { self.id_map.clone().read(), self.company_public_key.clone().read(), self.helper_shares.clone().write(), - ) {( - Ok(partner_features), - Ok(xor_shares_v2), - Ok(shuffler_gz), - Ok(id_map), - Ok(company_pk), - Ok(mut shares), - ) => { + ) { + ( + Ok(partner_features), + Ok(xor_shares_v2), + Ok(shuffler_gz), + Ok(id_map), + Ok(company_pk), + Ok(mut shares), + ) => { let t = timer::Timer::new_silent("helper calculate_features_xor_shares"); let mut rng = rand::thread_rng(); let range = Uniform::new(0_u64, u64::MAX); @@ -658,16 +631,16 @@ impl HelperDspmcProtocol for HelperDspmc { let n_features = partner_features.len(); let (t_i, mut g_zi) = { - let z_i = (0..id_map.len()) - .map(|_| gen_scalar()) - .collect::>(); - let x = z_i.iter() + let z_i = (0..id_map.len()).map(|_| gen_scalar()).collect::>(); + let x = z_i + .iter() .map(|a| { - let x = - self.ec_cipher.to_bytes(&vec![a * (*company_pk).0]); + let x = self.ec_cipher.to_bytes(&[a * company_pk.0]); x[0].clone() - }).collect::>(); - let y = z_i.iter() + }) + .collect::>(); + let y = z_i + .iter() .map(|a| a * &RISTRETTO_BASEPOINT_TABLE) .collect::>(); (x, y) @@ -676,30 +649,33 @@ impl HelperDspmcProtocol for HelperDspmc { let mut d_flat = { let mut v_p = Vec::::new(); - let shuffler_gz_points = self.ec_cipher.to_points(&*shuffler_gz); + let shuffler_gz_points = self.ec_cipher.to_points(&shuffler_gz); for f_idx in (0..n_features).rev() { let mask = (0..id_map.len()) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); let t = id_map .iter() .enumerate() .map(|(i, (_, idx, exists))| { - let y = - if *exists { - if f_idx == 0 { - // If exists, overwrite g_z' with g_z from shuffler. - g_zi[i] = shuffler_gz_points[*idx]; - } - // v'' xor v''' xor mask = v'' xor v' xor r xor mask = - // v xor r xor mask - xor_shares_v2[f_idx][*idx] ^ partner_features[f_idx][*idx] ^ mask[i] - } else { - // If it doesn't exist, r xor mask - let y = u64::from_le_bytes((t_i[i].buffer[0..8]).try_into().unwrap()); - y ^ mask[i] - }; + let y = if *exists { + if f_idx == 0 { + // If exists, overwrite g_z' with g_z from shuffler. + g_zi[i] = shuffler_gz_points[*idx]; + } + // v'' xor v''' xor mask = v'' xor v' xor r xor mask = + // v xor r xor mask + xor_shares_v2[f_idx][*idx] + ^ partner_features[f_idx][*idx] + ^ mask[i] + } else { + // If it doesn't exist, r xor mask + let y = u64::from_le_bytes( + (t_i[i].buffer[0..8]).try_into().unwrap(), + ); + y ^ mask[i] + }; ByteBuffer { buffer: y.to_le_bytes().to_vec(), } @@ -721,7 +697,7 @@ impl HelperDspmcProtocol for HelperDspmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); @@ -748,7 +724,7 @@ impl HelperDspmcProtocol for HelperDspmc { .max() .unwrap(); - let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx+1]; + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; for i in 0..id_map.len() { let (_, idx, flag) = id_map[i]; @@ -774,7 +750,7 @@ impl HelperDspmcProtocol for HelperDspmc { .max() .unwrap(); - let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx+1]; + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; for i in 0..id_map.len() { let (_, idx, flag) = id_map[i]; diff --git a/protocol/src/dspmc/mod.rs b/protocol/src/dspmc/mod.rs index 1bad339..0780b37 100644 --- a/protocol/src/dspmc/mod.rs +++ b/protocol/src/dspmc/mod.rs @@ -3,16 +3,16 @@ extern crate csv; -use common::{files, timer}; +use std::collections::HashSet; +use std::error::Error; +use std::fmt; +use std::sync::Arc; +use std::sync::RwLock; + +use common::files; +use common::timer; use crypto::prelude::*; -use std::{ - collections::HashSet, - sync::{Arc, RwLock}, -}; - -use std::{error::Error, fmt}; - #[derive(Debug)] pub enum ProtocolError { ErrorDeserialization(String), @@ -40,8 +40,10 @@ fn load_data_keys(plaintext: Arc>>>, path: &str, input_wi if let Ok(mut data) = plaintext.write() { data.clear(); let mut line_it = lines.drain(..); - // Strip the header for now - if input_with_headers && line_it.next().is_some() {} + // Strip the header + if input_with_headers { + line_it.next(); + } let mut t = HashSet::>::new(); // Filter out zero length strings - these will come from ragged @@ -173,7 +175,7 @@ fn serialize_helper(data: Vec>) -> (Vec, TPayload, TPayload) { } pub mod company; -pub mod shuffler; pub mod helper; pub mod partner; +pub mod shuffler; pub mod traits; diff --git a/protocol/src/dspmc/partner.rs b/protocol/src/dspmc/partner.rs index 273126b..58c5392 100644 --- a/protocol/src/dspmc/partner.rs +++ b/protocol/src/dspmc/partner.rs @@ -3,28 +3,24 @@ extern crate csv; -use crypto::{ - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - prelude::*, -}; - -use crate::{ - dspmc::traits::PartnerDspmcProtocol, - shared::TFeatures, -}; -use rand::prelude::*; +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; +use rand::prelude::*; -use std::{ - convert::TryInto, - sync::{Arc, RwLock}, -}; - -use super::{ - load_data_keys, load_data_features, serialize_helper, ProtocolError -}; +use super::load_data_features; +use super::load_data_keys; +use super::serialize_helper; +use super::ProtocolError; +use crate::dspmc::traits::PartnerDspmcProtocol; +use crate::shared::TFeatures; pub struct PartnerDspmc { company_public_key: Arc>, @@ -68,17 +64,20 @@ impl PartnerDspmc { self.plaintext_keys.clone().read().unwrap().len() } - pub fn set_company_public_key(&self, company_public_key: TPayload) -> Result<(), ProtocolError> { + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&company_public_key); // Check that two keys are sent assert_eq!(pk.len(), 2); match self.company_public_key.clone().write() { Ok(mut company_pk) => { - (*company_pk).0 = pk[0]; - (*company_pk).1 = pk[1]; - assert_eq!(((*company_pk).0).is_identity(), false); - assert_eq!(((*company_pk).1).is_identity(), false); + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); Ok(()) } _ => { @@ -108,7 +107,6 @@ impl PartnerDspmc { } } } - } impl Default for PartnerDspmc { @@ -140,22 +138,15 @@ impl PartnerDspmcProtocol for PartnerDspmc { // with EC: company_pk * r let (ct1, pkd_r) = { let r_i = (0..d_flat.len()) - .map(|x| x) .collect::>() .iter() .map(|_| gen_scalar()) .collect::>(); let ct1_bytes = { - let t1 = r_i - .iter() - .map(|x| *x * (*company_pk).0) - .collect::>(); + let t1 = r_i.iter().map(|x| *x * company_pk.0).collect::>(); self.ec_cipher.to_bytes(&t1) }; - let pkd_r = r_i - .iter() - .map(|x| *x * (*helper_pk)) - .collect::>(); + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); (ct1_bytes, pkd_r) }; @@ -165,12 +156,7 @@ impl PartnerDspmcProtocol for PartnerDspmc { .map(|(s, t)| *s + *t) .collect::>(); - ( - self.ec_cipher.to_bytes(ct2.as_slice()), - ct1, - offset, - ) - + (self.ec_cipher.to_bytes(ct2.as_slice()), ct1, offset) }; // Append ct1 @@ -198,7 +184,7 @@ impl PartnerDspmcProtocol for PartnerDspmc { self.helper_public_key.clone().read(), ) { (Ok(pdata), Ok(helper_pk)) => { - let t = timer::Timer::new_silent("get_features_xor_shares"); + let t = timer::Timer::new_silent("get_features_xor_shares"); let n_rows = pdata[0].len(); let n_features = pdata.len(); @@ -206,11 +192,9 @@ impl PartnerDspmcProtocol for PartnerDspmc { // PRG seed = scalar * PK_helper let (seed, ct3) = { let x = gen_scalar(); - let ct3 = self.ec_cipher.to_bytes( - &vec![&x * &RISTRETTO_BASEPOINT_TABLE] - ); + let ct3 = self.ec_cipher.to_bytes(&[&x * &RISTRETTO_BASEPOINT_TABLE]); let seed: [u8; 32] = { - let t = self.ec_cipher.to_bytes(&vec![&x * (*helper_pk)]); + let t = self.ec_cipher.to_bytes(&[x * (*helper_pk)]); t[0].buffer.as_slice().try_into().expect("incorrect length") }; (seed, ct3) @@ -220,7 +204,6 @@ impl PartnerDspmcProtocol for PartnerDspmc { let mut v2 = TFeatures::new(); for _ in 0..n_features { let t = (0..n_rows) - .map(|x| x) .collect::>() .iter() .map(|_| rng.gen::()) @@ -242,7 +225,7 @@ impl PartnerDspmcProtocol for PartnerDspmc { buffer: z.to_le_bytes().to_vec(), } }) - .collect::>(); + .collect::>(); v_p.push(t); } @@ -255,7 +238,7 @@ impl PartnerDspmcProtocol for PartnerDspmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); d_flat.extend(ct3); diff --git a/protocol/src/dspmc/shuffler.rs b/protocol/src/dspmc/shuffler.rs index 613f278..a8cc0a1 100644 --- a/protocol/src/dspmc/shuffler.rs +++ b/protocol/src/dspmc/shuffler.rs @@ -3,35 +3,25 @@ extern crate csv; -use crypto::{ - eccipher::{ECCipher, ECRistrettoParallel, gen_scalar}, - prelude::*, -}; +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; use itertools::Itertools; +use rand::distributions::Uniform; +use rand::Rng; -use crate::{ - dspmc::traits::ShufflerDspmcProtocol, - shared::TFeatures, -}; - -use common::{ - permutations::{gen_permute_pattern, permute}, - timer, -}; - -use rand::{ - distributions::Uniform, - Rng, -}; - -use std::{ - convert::TryInto, - sync::{Arc, RwLock}, -}; - -use super::{ - serialize_helper, ProtocolError -}; +use super::serialize_helper; +use super::ProtocolError; +use crate::dspmc::traits::ShufflerDspmcProtocol; +use crate::shared::TFeatures; pub struct ShufflerDspmc { company_public_key: Arc>, @@ -39,9 +29,9 @@ pub struct ShufflerDspmc { ec_cipher: ECRistrettoParallel, p_cs: Arc>>, v_cs: Arc>>, - perms: Arc, Vec)>>, // (p_sc, p_sd) - blinds: Arc, Vec)>>, // (v_sc, v_sd) - xor_shares_v1: Arc>, // v' + perms: Arc, Vec)>>, // (p_sc, p_sd) + blinds: Arc, Vec)>>, // (v_sc, v_sd) + xor_shares_v1: Arc>, // v' ct1_dprime: Arc>>>, ct2_dprime: Arc>>>, } @@ -62,8 +52,9 @@ impl ShufflerDspmc { } } - pub fn set_company_public_key(&self, - company_public_key: TPayload + pub fn set_company_public_key( + &self, + company_public_key: TPayload, ) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&company_public_key); // Check that two keys are sent @@ -71,10 +62,10 @@ impl ShufflerDspmc { match self.company_public_key.clone().write() { Ok(mut company_pk) => { - (*company_pk).0 = pk[0]; - (*company_pk).1 = pk[1]; - assert_eq!(((*company_pk).0).is_identity(), false); - assert_eq!(((*company_pk).1).is_identity(), false); + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); Ok(()) } _ => { @@ -86,9 +77,7 @@ impl ShufflerDspmc { } } - pub fn set_helper_public_key(&self, - helper_public_key: TPayload - ) -> Result<(), ProtocolError> { + pub fn set_helper_public_key(&self, helper_public_key: TPayload) -> Result<(), ProtocolError> { let pk = self.ec_cipher.to_points(&helper_public_key); // Check that one key is sent assert_eq!(pk.len(), 1); @@ -107,7 +96,6 @@ impl ShufflerDspmc { } } } - } impl Default for ShufflerDspmc { @@ -117,27 +105,22 @@ impl Default for ShufflerDspmc { } impl ShufflerDspmcProtocol for ShufflerDspmc { - fn set_p_cs_v_cs( &self, v_cs_bytes: TPayload, p_cs_bytes: TPayload, ) -> Result<(), ProtocolError> { - match (self.p_cs.clone().write(), self.v_cs.clone().write(),) { + match (self.p_cs.clone().write(), self.v_cs.clone().write()) { (Ok(mut p_cs), Ok(mut v_cs)) => { let t = timer::Timer::new_silent("set p_cs, v_cs"); *v_cs = v_cs_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>(); *p_cs = p_cs_bytes .iter() - .map(|x| { - u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize - }) + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) .collect::>(); t.qps("deserialize_exp", p_cs_bytes.len()); @@ -146,7 +129,7 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { _ => { error!("Cannot load p_cs, v_cs"); Err(ProtocolError::ErrorDeserialization( - "cannot load p_cs, v_cs".to_string(), + "cannot load p_cs, v_cs".to_string(), )) } } @@ -170,36 +153,44 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { perms.1.extend(gen_permute_pattern(data_len)); blinds.0 = (0..data_len) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); blinds.1 = (0..data_len) - .map(|_| rng.sample(&range)) + .map(|_| rng.sample(range)) .collect::>(); - let mut p_sc_v_sc = perms.0.iter() + let mut p_sc_v_sc = perms + .0 + .iter() .map(|e| ByteBuffer { - buffer: (*e as u64).to_le_bytes().to_vec(), + buffer: (*e).to_le_bytes().to_vec(), }) .collect::>(); - let v_sc_bytes = blinds.0.iter() + let v_sc_bytes = blinds + .0 + .iter() .map(|e| ByteBuffer { - buffer: (*e as u64).to_le_bytes().to_vec(), + buffer: (*e).to_le_bytes().to_vec(), }) .collect::>(); p_sc_v_sc.extend(v_sc_bytes); - let mut p_sd_v_sd = perms.1.iter() + let mut p_sd_v_sd = perms + .1 + .iter() .map(|e| ByteBuffer { buffer: (*e as u64).to_le_bytes().to_vec(), }) .collect::>(); - let v_sd_bytes = blinds.1.iter() + let v_sd_bytes = blinds + .1 + .iter() .map(|e| ByteBuffer { buffer: (*e as u64).to_le_bytes().to_vec(), }) .collect::>(); p_sd_v_sd.extend(v_sd_bytes); - p_sd_v_sd.push(ByteBuffer{ + p_sd_v_sd.push(ByteBuffer { buffer: (data_len as u64).to_le_bytes().to_vec(), }); @@ -232,8 +223,17 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { self.ct1_dprime.clone().write(), self.ct2_dprime.clone().write(), ) { - (Ok(p_cs), Ok(v_cs), Ok(perms), Ok(blinds), Ok(company_pk), - Ok(helper_pk), Ok(mut v_p), Ok(mut ct1_dprime), Ok(mut ct2_dprime)) => { + ( + Ok(p_cs), + Ok(v_cs), + Ok(perms), + Ok(blinds), + Ok(company_pk), + Ok(helper_pk), + Ok(mut v_p), + Ok(mut ct1_dprime), + Ok(mut ct2_dprime), + ) => { // This is an array of exclusive-inclusive prefix sum - hence // number of keys is one less than length let num_keys = psum.len() - 1; @@ -266,49 +266,42 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { // Compute p_sd(p_sc(p_cs( u2 ) xor v_cs) xor v_sc) xor v_sd (*v_p).clear(); for f_idx in (0..n_features).rev() { - permute(p_cs.as_slice(), &mut u2[f_idx]); // p_cs + permute(p_cs.as_slice(), &mut u2[f_idx]); // p_cs let mut x_2 = u2[f_idx] .iter() .zip_eq(v_cs.iter()) .map(|(s, v_cs)| *s ^ *v_cs) .collect::>(); - permute(perms.0.as_slice(), &mut x_2); // p_sc + permute(perms.0.as_slice(), &mut x_2); // p_sc let mut t_1 = x_2 .iter() .zip_eq(blinds.0.iter()) .map(|(s, v_sc)| *s ^ *v_sc) .collect::>(); - permute(perms.1.as_slice(), &mut t_1); // p_sd - (*v_p).push(t_1 - .iter() - .zip_eq(blinds.1.iter()) - .map(|(s, v_sd)| *s ^ *v_sd) - .collect::>() + permute(perms.1.as_slice(), &mut t_1); // p_sd + (*v_p).push( + t_1.iter() + .zip_eq(blinds.1.iter()) + .map(|(s, v_sd)| *s ^ *v_sd) + .collect::>(), ); } // Re-randomize ct1'' and ct2'' and flatten let (mut ct1_dprime_flat, ct2_dprime_flat, ct_offset) = { let r_i = (0..n_rows) - .map(|x| x) .collect::>() .iter() .map(|_| gen_scalar()) .collect::>(); // company_pk^r // with EC: company_pk * r - let pkc_r = r_i - .iter() - .map(|x| *x * ((*company_pk).0)) - .collect::>(); + let pkc_r = r_i.iter().map(|x| *x * (company_pk.0)).collect::>(); // helper_pk^r // with EC: helper_pk * r - let pkd_r = r_i - .iter() - .map(|x| *x * (*helper_pk)) - .collect::>(); + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); permute(perms.0.as_slice(), &mut ct1_prime); // p_sc permute(perms.0.as_slice(), &mut ct2_prime); // p_sc @@ -320,18 +313,14 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { *ct1_dprime = ct1_prime .iter() .zip_eq(pkc_r.iter()) - .map(|(s, t)| { - (*s).iter().map(|si| *si + *t).collect::>() - }) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) .collect::>(); // ct2' = ct2'' * helper_pk^r // with EC: ct2' = ct2'' + helper_pk*r *ct2_dprime = ct2_prime .iter() .zip_eq(pkd_r.iter()) - .map(|(s, t)| { - (*s).iter().map(|si| *si + *t).collect::>() - }) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) .collect::>(); let (ct1_dprime_flat, _ct1_offset) = { let (d_flat, mut offset, metadata) = serialize_helper(ct1_dprime.clone()); @@ -367,27 +356,21 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { self.company_public_key.clone().read(), self.xor_shares_v1.clone().read(), ) { - (Ok(company_pk), Ok(v_p),) => { + (Ok(company_pk), Ok(v_p)) => { let t = timer::Timer::new_silent("get_blinded_vprime"); let n_rows = v_p[0].len(); let n_features = v_p.len(); - let z_i = (0..n_rows) - .map(|_| gen_scalar()) - .collect::>(); + let z_i = (0..n_rows).map(|_| gen_scalar()).collect::>(); let mut d_flat = { let r_i = { let y_zi = { - let t = z_i - .iter() - .map(|x| *x * (*company_pk).0) - .collect::>(); + let t = z_i.iter().map(|x| *x * company_pk.0).collect::>(); self.ec_cipher.to_bytes(&t) }; - y_zi - .iter() + y_zi.iter() .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) .collect::>() }; @@ -396,7 +379,6 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { for f_idx in (0..n_features).rev() { let t = (0..n_rows) - .map(|x| x) .collect::>() .iter() .map(|i| { @@ -429,7 +411,7 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { }, ByteBuffer { buffer: (n_features as u64).to_le_bytes().to_vec(), - } + }, ]; d_flat.extend(metadata); @@ -444,5 +426,4 @@ impl ShufflerDspmcProtocol for ShufflerDspmc { } } } - } diff --git a/protocol/src/dspmc/traits.rs b/protocol/src/dspmc/traits.rs index cc761cc..2b12ef3 100644 --- a/protocol/src/dspmc/traits.rs +++ b/protocol/src/dspmc/traits.rs @@ -3,19 +3,22 @@ extern crate crypto; -use crate::{ - dspmc::ProtocolError, - shared::TFeatures, -}; use crypto::prelude::TPayload; +use crate::dspmc::ProtocolError; +use crate::shared::TFeatures; + pub trait PartnerDspmcProtocol { fn get_encrypted_keys(&self) -> Result; fn get_features_xor_shares(&self) -> Result; } pub trait ShufflerDspmcProtocol { - fn set_p_cs_v_cs(&self, v_cs_bytes: TPayload, p_cs_bytes: TPayload) -> Result<(), ProtocolError>; + fn set_p_cs_v_cs( + &self, + v_cs_bytes: TPayload, + p_cs_bytes: TPayload, + ) -> Result<(), ProtocolError>; fn gen_permutations(&self) -> Result<(TPayload, TPayload), ProtocolError>; fn get_blinded_vprime(&self) -> Result; fn compute_v2prime_ct1ct2( @@ -41,7 +44,11 @@ pub trait CompanyDspmcProtocol { fn get_ct1_ct2(&self) -> Result; fn get_p_cs_v_cs(&self) -> Result; fn get_u1(&self) -> Result; - fn calculate_features_xor_shares(&self, features: TFeatures, g_zi: TPayload) -> Result<(), ProtocolError>; + fn calculate_features_xor_shares( + &self, + features: TFeatures, + g_zi: TPayload, + ) -> Result<(), ProtocolError>; fn write_company_to_id_map(&self) -> Result<(), ProtocolError>; fn print_id_map(&self); fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 563a00d..dae101d 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -8,12 +8,12 @@ extern crate log; pub mod cross_psi; pub mod cross_psi_xor; +pub mod dpmc; +pub mod dspmc; pub mod fileio; pub mod pjc; pub mod private_id; pub mod private_id_multi_key; -pub mod dpmc; -pub mod dspmc; pub mod suid_create; pub mod shared {