Skip to content

Commit

Permalink
Add rebatching API for bitstring (#117)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #117

For UDP protocol, we need to add a new type of gate. Rebatching gate. This type of gate allows to break a batch of values into smaller batches or combine several batches into a larger one.

This diff adds the APIs for bit string type to batching/unbatching

Reviewed By: elliottlawrence

Differential Revision: D34914178

fbshipit-source-id: 411e1a60f2afdcfd72bb89a41e809ea9712781e2
  • Loading branch information
RuiyuZhu authored and facebook-github-bot committed Mar 21, 2022
1 parent 350f616 commit 0f006af
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fbpcf/frontend/BitString.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ class BitString : public scheduler::SchedulerKeeper<schedulerId> {
mux(const Bit<isSecretChoice, schedulerId, usingBatch>& choice,
const BitString<isSecretOther, schedulerId, usingBatch>& other) const;

BitString<isSecret, schedulerId, usingBatch> batchingWith(
const std::vector<BitString<isSecret, schedulerId, usingBatch>>& others)
const;

std::vector<BitString<isSecret, schedulerId, usingBatch>> unbatching(
std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const;

private:
std::vector<Bit<isSecret, schedulerId, usingBatch>> data_;

Expand Down
46 changes: 46 additions & 0 deletions fbpcf/frontend/BitString_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,50 @@ BitString<isSecret, schedulerId, usingBatch>::mux(
return rst;
}

template <bool isSecret, int schedulerId, bool usingBatch>
BitString<isSecret, schedulerId, usingBatch>
BitString<isSecret, schedulerId, usingBatch>::batchingWith(
const std::vector<BitString<isSecret, schedulerId, usingBatch>>& others)
const {
static_assert(usingBatch, "Only batch values needs to rebatch!");

for (auto& item : others) {
if (item.data_.size() != data_.size()) {
throw std::runtime_error(
"The BitStrings need to have the same length to batch together.");
}
}

BitString<isSecret, schedulerId, usingBatch> rst(data_.size());
size_t batchSize = others.size();
std::vector<Bit<true, schedulerId, usingBatch>> bits(batchSize);
for (size_t i = 0; i < data_.size(); i++) {
for (size_t j = 0; j < batchSize; j++) {
bits[j] = others.at(j).data_.at(i);
}
rst.data_[i] = data_.at(i).batchingWith(bits);
}
return rst;
}

template <bool isSecret, int schedulerId, bool usingBatch>
std::vector<BitString<isSecret, schedulerId, usingBatch>>
BitString<isSecret, schedulerId, usingBatch>::unbatching(
std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const {
static_assert(usingBatch, "Only batch values needs to rebatch!");
std::vector<BitString<isSecret, schedulerId, usingBatch>> rst(
unbatchingStrategy->size());
for (auto& item : rst) {
item.resize(data_.size());
}

for (size_t i = 0; i < data_.size(); i++) {
auto bitVec = data_.at(i).unbatching(unbatchingStrategy);
for (size_t j = 0; j < unbatchingStrategy->size(); j++) {
rst.at(j).data_.at(i) = bitVec.at(j);
}
}
return rst;
}

} // namespace fbpcf::frontend
70 changes: 70 additions & 0 deletions fbpcf/frontend/test/BitStringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "fbpcf/test/TestHelper.h"

namespace fbpcf::frontend {

TEST(StringTest, testInputAndOutput) {
std::random_device rd;
std::mt19937_64 e(rd());
Expand Down Expand Up @@ -453,4 +454,73 @@ TEST(StringTest, testResizeWithAND) {
}
}

TEST(StringTest, testRebatching) {
std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<uint32_t> dSize(1, 1024);

std::uniform_int_distribution<uint8_t> dBool(0, 1);

scheduler::SchedulerKeeper<0>::setScheduler(
std::make_unique<scheduler::PlaintextScheduler>(
scheduler::WireKeeper::createWithUnorderedMap()));

using SecBatchString = BitString<true, 0, true>;

std::vector<bool> testValue(dSize(e));
for (size_t i = 0; i < testValue.size(); i++) {
testValue[i] = dBool(e);
}
uint32_t length = dSize(e);
uint32_t batchSize1 = dSize(e);
uint32_t batchSize2 = dSize(e);
uint32_t batchSize3 = dSize(e);
std::vector<std::vector<bool>> testBatchValue1(
length, std::vector<bool>(batchSize1));
std::vector<std::vector<bool>> testBatchValue2(
length, std::vector<bool>(batchSize2));
std::vector<std::vector<bool>> testBatchValue3(
length, std::vector<bool>(batchSize3));

for (size_t i = 0; i < length; i++) {
for (size_t j = 0; j < batchSize1; j++) {
testBatchValue1[i][j] = dBool(e);
}
for (size_t j = 0; j < batchSize2; j++) {
testBatchValue2[i][j] = dBool(e);
}
for (size_t j = 0; j < batchSize3; j++) {
testBatchValue3[i][j] = dBool(e);
}
}

SecBatchString v1(testBatchValue1, 0);
SecBatchString v2(testBatchValue2, 0);
SecBatchString v3(testBatchValue3, 0);

auto v4 = v1.batchingWith({v2, v3});
auto v123 = v4.unbatching(std::make_shared<std::vector<uint32_t>>(
std::vector<uint32_t>({batchSize1, batchSize2, batchSize3})));

auto t4 = v4.openToParty(0).getValue();
auto t5 = v123.at(0).openToParty(0).getValue();
auto t6 = v123.at(1).openToParty(0).getValue();
auto t7 = v123.at(2).openToParty(0).getValue();

for (size_t i = 0; i < length; i++) {
testVectorEq(t5.at(i), testBatchValue1.at(i));
testVectorEq(t6.at(i), testBatchValue2.at(i));
testVectorEq(t7.at(i), testBatchValue3.at(i));
testBatchValue1[i].insert(
testBatchValue1[i].end(),
testBatchValue2[i].begin(),
testBatchValue2[i].end());
testBatchValue1[i].insert(
testBatchValue1[i].end(),
testBatchValue3[i].begin(),
testBatchValue3[i].end());
testVectorEq(t4.at(i), testBatchValue1.at(i));
}
}

} // namespace fbpcf::frontend

0 comments on commit 0f006af

Please sign in to comment.