diff --git a/fbpcf/frontend/BitString.h b/fbpcf/frontend/BitString.h index 02e05f77..695b08b2 100644 --- a/fbpcf/frontend/BitString.h +++ b/fbpcf/frontend/BitString.h @@ -121,6 +121,13 @@ class BitString : public scheduler::SchedulerKeeper { mux(const Bit& choice, const BitString& other) const; + BitString batchingWith( + const std::vector>& others) + const; + + std::vector> unbatching( + std::shared_ptr> unbatchingStrategy) const; + private: std::vector> data_; diff --git a/fbpcf/frontend/BitString_impl.h b/fbpcf/frontend/BitString_impl.h index 80cf3bbd..3b29de52 100644 --- a/fbpcf/frontend/BitString_impl.h +++ b/fbpcf/frontend/BitString_impl.h @@ -160,4 +160,50 @@ BitString::mux( return rst; } +template +BitString +BitString::batchingWith( + const std::vector>& others) + const { + static_assert(usingBatch, "Only batch values needs to rebatch!"); + + for (auto& item : others) { + if (item.data_.size() != data_.size()) { + throw std::runtime_error( + "The BitStrings need to have the same length to batch together."); + } + } + + BitString rst(data_.size()); + size_t batchSize = others.size(); + std::vector> bits(batchSize); + for (size_t i = 0; i < data_.size(); i++) { + for (size_t j = 0; j < batchSize; j++) { + bits[j] = others.at(j).data_.at(i); + } + rst.data_[i] = data_.at(i).batchingWith(bits); + } + return rst; +} + +template +std::vector> +BitString::unbatching( + std::shared_ptr> unbatchingStrategy) const { + static_assert(usingBatch, "Only batch values needs to rebatch!"); + std::vector> rst( + unbatchingStrategy->size()); + for (auto& item : rst) { + item.resize(data_.size()); + } + + for (size_t i = 0; i < data_.size(); i++) { + auto bitVec = data_.at(i).unbatching(unbatchingStrategy); + for (size_t j = 0; j < unbatchingStrategy->size(); j++) { + rst.at(j).data_.at(i) = bitVec.at(j); + } + } + return rst; +} + } // namespace fbpcf::frontend diff --git a/fbpcf/frontend/test/BitStringTest.cpp b/fbpcf/frontend/test/BitStringTest.cpp index 78726eba..884b6638 100644 --- a/fbpcf/frontend/test/BitStringTest.cpp +++ b/fbpcf/frontend/test/BitStringTest.cpp @@ -16,6 +16,7 @@ #include "fbpcf/test/TestHelper.h" namespace fbpcf::frontend { + TEST(StringTest, testInputAndOutput) { std::random_device rd; std::mt19937_64 e(rd()); @@ -453,4 +454,73 @@ TEST(StringTest, testResizeWithAND) { } } +TEST(StringTest, testRebatching) { + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution dSize(1, 1024); + + std::uniform_int_distribution dBool(0, 1); + + scheduler::SchedulerKeeper<0>::setScheduler( + std::make_unique( + scheduler::WireKeeper::createWithUnorderedMap())); + + using SecBatchString = BitString; + + std::vector testValue(dSize(e)); + for (size_t i = 0; i < testValue.size(); i++) { + testValue[i] = dBool(e); + } + uint32_t length = dSize(e); + uint32_t batchSize1 = dSize(e); + uint32_t batchSize2 = dSize(e); + uint32_t batchSize3 = dSize(e); + std::vector> testBatchValue1( + length, std::vector(batchSize1)); + std::vector> testBatchValue2( + length, std::vector(batchSize2)); + std::vector> testBatchValue3( + length, std::vector(batchSize3)); + + for (size_t i = 0; i < length; i++) { + for (size_t j = 0; j < batchSize1; j++) { + testBatchValue1[i][j] = dBool(e); + } + for (size_t j = 0; j < batchSize2; j++) { + testBatchValue2[i][j] = dBool(e); + } + for (size_t j = 0; j < batchSize3; j++) { + testBatchValue3[i][j] = dBool(e); + } + } + + SecBatchString v1(testBatchValue1, 0); + SecBatchString v2(testBatchValue2, 0); + SecBatchString v3(testBatchValue3, 0); + + auto v4 = v1.batchingWith({v2, v3}); + auto v123 = v4.unbatching(std::make_shared>( + std::vector({batchSize1, batchSize2, batchSize3}))); + + auto t4 = v4.openToParty(0).getValue(); + auto t5 = v123.at(0).openToParty(0).getValue(); + auto t6 = v123.at(1).openToParty(0).getValue(); + auto t7 = v123.at(2).openToParty(0).getValue(); + + for (size_t i = 0; i < length; i++) { + testVectorEq(t5.at(i), testBatchValue1.at(i)); + testVectorEq(t6.at(i), testBatchValue2.at(i)); + testVectorEq(t7.at(i), testBatchValue3.at(i)); + testBatchValue1[i].insert( + testBatchValue1[i].end(), + testBatchValue2[i].begin(), + testBatchValue2[i].end()); + testBatchValue1[i].insert( + testBatchValue1[i].end(), + testBatchValue3[i].begin(), + testBatchValue3[i].end()); + testVectorEq(t4.at(i), testBatchValue1.at(i)); + } +} + } // namespace fbpcf::frontend