Skip to content

Commit

Permalink
Implement the max function for integer type and its unit tests (#179)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #179

Implement the max function for integer type and its unit tests

Reviewed By: RuiyuZhu

Differential Revision: D35390074

fbshipit-source-id: 36aa50fce720314c2d645c1e91fafc9eeee46c43
  • Loading branch information
adelesun authored and facebook-github-bot committed Apr 6, 2022
1 parent 6d623e0 commit 8fb6035
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 0 deletions.
11 changes: 11 additions & 0 deletions fbpcf/frontend/Int.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ using Integer =
typename IntTypeHelper<typename T::type, IsSecret<T>::value, schedulerId>::
type;

template <
bool isSigned,
int8_t width,
bool isSecret1,
bool isSecret2,
int schedulerId,
bool usingBatch>
Int<isSigned, width, isSecret1 || isSecret2, schedulerId, usingBatch> max(
const Int<isSigned, width, isSecret1, schedulerId, usingBatch>& left,
const Int<isSigned, width, isSecret2, schedulerId, usingBatch>& right);

} // namespace fbpcf::frontend

#include "fbpcf/frontend/Int_impl.h"
13 changes: 13 additions & 0 deletions fbpcf/frontend/Int_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,4 +604,17 @@ Int<isSigned, width, isSecret, schedulerId, usingBatch>::unbatching(
return rst;
}

template <
bool isSigned,
int8_t width,
bool isSecret1,
bool isSecret2,
int schedulerId,
bool usingBatch>
Int<isSigned, width, isSecret1 || isSecret2, schedulerId, usingBatch> max(
const Int<isSigned, width, isSecret1, schedulerId, usingBatch>& left,
const Int<isSigned, width, isSecret2, schedulerId, usingBatch>& right) {
return left.mux(left < right, right);
}

} // namespace fbpcf::frontend
144 changes: 144 additions & 0 deletions fbpcf/frontend/test/IntTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1434,4 +1434,148 @@ TEST(IntTest, testRebatch) {
testVectorEq(int123.at(2).openToParty(partyId).getValue(), v3);
}

TEST(IntTest, testMax) {
const int8_t width = 64;

int64_t largestSigned = std::numeric_limits<int64_t>().max();
int64_t smallestSigned = std::numeric_limits<int64_t>().min();
uint64_t largestUnsigned = std::numeric_limits<uint64_t>().max();

scheduler::SchedulerKeeper<0>::setScheduler(
std::make_unique<scheduler::PlaintextScheduler>(
scheduler::WireKeeper::createWithUnorderedMap()));
using secSignedInt = Integer<Secret<Signed<width>>, 0>;
using pubSignedInt = Integer<Public<Signed<width>>, 0>;
using secUnsignedInt = Integer<Secret<Unsigned<width>>, 0>;
using pubUnsignedInt = Integer<Public<Unsigned<width>>, 0>;

int partyId = 2;

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<int64_t> dist1(smallestSigned, largestSigned);

std::uniform_int_distribution<uint64_t> dist2(0, largestUnsigned);

for (int i = 0; i < 100; i++) {
int64_t v1 = dist1(e);
int64_t v2 = dist1(e);

secSignedInt int1(v1, partyId);
secSignedInt int2(v2, partyId);
pubSignedInt int3(v1);
pubSignedInt int4(v2);

auto r1 = max(int1, int2);
auto r2 = max(int1, int4);
auto r3 = max(int3, int4);

auto expectedValue = v1 < v2 ? v2 : v1;

EXPECT_EQ(r1.openToParty(partyId).getValue(), expectedValue);
EXPECT_EQ(r2.openToParty(partyId).getValue(), expectedValue);
EXPECT_EQ(r3.getValue(), expectedValue);
}

for (int i = 0; i < 100; i++) {
uint64_t v1 = dist2(e);
uint64_t v2 = dist2(e);

secUnsignedInt int1(v1, partyId);
secUnsignedInt int2(v2, partyId);
pubUnsignedInt int3(v1);
pubUnsignedInt int4(v2);

auto r1 = max(int1, int2);
auto r2 = max(int1, int4);
auto r3 = max(int3, int4);

auto expectedValue = v1 < v2 ? v2 : v1;

EXPECT_EQ(r1.openToParty(partyId).getValue(), expectedValue);
EXPECT_EQ(r2.openToParty(partyId).getValue(), expectedValue);
EXPECT_EQ(r3.getValue(), expectedValue);
}
}

TEST(IntTest, testMaxBatch) {
const int8_t width = 64;

int64_t largestSigned = std::numeric_limits<int64_t>().max();
int64_t smallestSigned = std::numeric_limits<int64_t>().min();
uint64_t largestUnsigned = std::numeric_limits<uint64_t>().max();

scheduler::SchedulerKeeper<0>::setScheduler(
std::make_unique<scheduler::PlaintextScheduler>(
scheduler::WireKeeper::createWithUnorderedMap()));
using secSignedIntBatch = Integer<Secret<Batch<Signed<width>>>, 0>;
using pubSignedIntBatch = Integer<Public<Batch<Signed<width>>>, 0>;
using secUnsignedIntBatch = Integer<Secret<Batch<Unsigned<width>>>, 0>;
using pubUnsignedIntBatch = Integer<Public<Batch<Unsigned<width>>>, 0>;

size_t batchSize = 9;

int partyId = 2;

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<int64_t> dist1(smallestSigned, largestSigned);

std::uniform_int_distribution<uint64_t> dist2(0, largestUnsigned);

for (int i = 0; i < 100; i++) {
std::vector<int64_t> v1(batchSize);
std::vector<int64_t> v2(batchSize);
for (size_t j = 0; j < batchSize; j++) {
v1[j] = dist1(e);
v2[j] = dist1(e);
}

secSignedIntBatch int1(v1, partyId);
secSignedIntBatch int2(v2, partyId);
pubSignedIntBatch int3(v1);
pubSignedIntBatch int4(v2);

auto r1 = max(int1, int2);
auto r2 = max(int1, int4);
auto r3 = max(int3, int4);

std::vector<int64_t> expectedValue(batchSize);
for (size_t j = 0; j < batchSize; j++) {
expectedValue[j] = v1[j] < v2[j] ? v2[j] : v1[j];
}

testVectorEq(r1.openToParty(partyId).getValue(), expectedValue);
testVectorEq(r2.openToParty(partyId).getValue(), expectedValue);
testVectorEq(r3.getValue(), expectedValue);
}

for (int i = 0; i < 100; i++) {
std::vector<uint64_t> v1(batchSize);
std::vector<uint64_t> v2(batchSize);
for (size_t j = 0; j < batchSize; j++) {
v1[j] = dist2(e);
v2[j] = dist2(e);
}

secUnsignedIntBatch int1(v1, partyId);
secUnsignedIntBatch int2(v2, partyId);
pubUnsignedIntBatch int3(v1);
pubUnsignedIntBatch int4(v2);

auto r1 = max(int1, int2);
auto r2 = max(int1, int4);
auto r3 = max(int3, int4);

std::vector<uint64_t> expectedValue(batchSize);
for (size_t j = 0; j < batchSize; j++) {
expectedValue[j] = v1[j] < v2[j] ? v2[j] : v1[j];
}

testVectorEq(r1.openToParty(partyId).getValue(), expectedValue);
testVectorEq(r2.openToParty(partyId).getValue(), expectedValue);
testVectorEq(r3.getValue(), expectedValue);
}
}

} // namespace fbpcf::frontend

0 comments on commit 8fb6035

Please sign in to comment.