From 2f51aa1468cca73a81494f4599e3464d4e6bdb96 Mon Sep 17 00:00:00 2001 From: Tal Davidi Date: Thu, 23 Feb 2023 10:18:06 -0800 Subject: [PATCH] Add packed bit field support to Row Structure (#493) Summary: Pull Request resolved: https://github.com/facebookresearch/fbpcf/pull/493 # Background: Currently in order to successfully use UDP, you must write some carefully crafted code that will take all the rows of metadata for one side and package it into a collection of bytes. Afterwards the caller will get a `SecString` object back which is a bit representation of all the bytes they passed in, minus the filtered out rows. The user must then extract the corresponding bits for each column into separate MPC Types. This is a cumbersome process which is error prone, as you must make sure to carefully match up the two steps and any changes can cause a bug. # This Diff Adds support to RowDefinition for packed bits. Right now you directly pass in the column names that are to be packed, and the interface will combine each column. The output will contain columns in a vector form. I will refactor that in a future diff to make a bit easier to use. Reviewed By: haochenuw Differential Revision: D43366173 fbshipit-source-id: 6604bad668dbd4acaf7e59d3b60dfc0be217a091 --- .../serialization/RowStructureDefinition.h | 54 +++++++++++++++++++ .../test/RowSerializationTest.cpp | 46 ++++++++++++++-- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h index 00b99597..cfef79d9 100644 --- a/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h @@ -79,6 +79,11 @@ class RowStructureDefinition : public IRowStructureDefinition { auto columnType = columnDefinition->getColumnType(); switch (columnType) { + case IColumnDefinition< + schedulerId>::SupportedColumnTypes::PackedBitField: + serializePackedBitFieldColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; case IColumnDefinition::SupportedColumnTypes::UInt32: serializeIntegerColumn( columnPointer, data, writeBuffers, numRows, byteOffset); @@ -179,6 +184,55 @@ class RowStructureDefinition : public IRowStructureDefinition { return result; } + void serializePackedBitFieldColumn( + const IColumnDefinition* columnPointer, + const std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType>& + inputData, + std::vector>& writeBuffers, + int numRows, + size_t byteOffset) const { + const PackedBitFieldColumn* packedBitCol = + dynamic_cast*>(columnPointer); + + if (packedBitCol == nullptr) { + throw std::runtime_error("Failed to cast to PackedBitFieldColumn"); + } + std::vector> bitPack( + numRows, std::vector(packedBitCol->getSubColumnNames().size())); + + for (int i = 0; i < packedBitCol->getSubColumnNames().size(); i++) { + std::string colName = packedBitCol->getSubColumnNames()[i]; + if (!inputData.contains(colName)) { + throw std::runtime_error( + "Column: " + colName + + " which was defined in the structure was not included in the input data map."); + } + + const std::vector bitVals = + std::get>(inputData.at(colName)); + + if (bitVals.size() != numRows) { + std::string err = folly::sformat( + "Invalid number of values for column {} .Got {} values but number of rows should be {} ", + colName, + bitVals.size(), + numRows); + throw std::runtime_error(err); + } + + for (int j = 0; j < numRows; j++) { + bitPack[j][i] = bitVals[j]; + } + } + + for (int i = 0; i < numRows; i++) { + packedBitCol->serializeColumnAsPlaintextBytes( + bitPack.data() + i, writeBuffers[i].data() + byteOffset); + } + } + template void serializeIntegerColumn( const IColumnDefinition* columnPointer, diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp index 4c69f831..358db708 100644 --- a/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp @@ -14,8 +14,10 @@ #include "fbpcf/scheduler/SchedulerHelper.h" #include "fbpcf/test/TestHelper.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/FixedSizeArrayColumn.h" #include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h" #include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/PackedBitFieldColumn.h" #include "fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h" namespace fbpcf::mpc_std_lib::unified_data_process::serialization { @@ -31,6 +33,12 @@ std::unique_ptr> createRowDefinition() { std::make_unique>("int64Column")); columnDefs->push_back( std::make_unique>("uint32Column")); + + std::vector bitColumnNames = { + "boolColumn1", "boolColumn2", "boolColumn3", "boolColumn4"}; + columnDefs->push_back(std::make_unique>( + "packedBits", bitColumnNames)); + auto serializer = std::make_unique>( std::move(columnDefs)); return std::move(serializer); @@ -112,6 +120,16 @@ deserializeAndRevealAllColumns( [](uint64_t data) { return data; }); rst.emplace("uint32Column", uint32Data); + std::vector::SecBool> + packedBitsMPCValue = std::get< + std::vector::SecBool>>( + deserialization.at("packedBits")); + + rst.emplace("boolColumn1", packedBitsMPCValue[0].openToParty(0).getValue()); + rst.emplace("boolColumn2", packedBitsMPCValue[1].openToParty(0).getValue()); + rst.emplace("boolColumn3", packedBitsMPCValue[2].openToParty(0).getValue()); + rst.emplace("boolColumn4", packedBitsMPCValue[3].openToParty(0).getValue()); + return rst; } @@ -130,6 +148,7 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) { std::random_device rd; std::mt19937_64 e(rd()); + std::uniform_int_distribution<> boolDist(0, 1); std::uniform_int_distribution uint32Dist( std::numeric_limits().min(), std::numeric_limits().max()); @@ -145,17 +164,26 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) { std::unique_ptr> serializer1 = createRowDefinition<1>(); - EXPECT_EQ(serializer0->getRowSizeBytes(), 16); - EXPECT_EQ(serializer1->getRowSizeBytes(), 16); + EXPECT_EQ(serializer0->getRowSizeBytes(), 17); + EXPECT_EQ(serializer1->getRowSizeBytes(), 17); std::vector int32Data(0); std::vector int64Data(0); std::vector uint32Data(0); + std::vector boolData1(0); + std::vector boolData2(0); + std::vector boolData3(0); + std::vector boolData4(0); for (int i = 0; i < batchSize; i++) { int32Data.push_back(int32Dist(e)); int64Data.push_back(int64Dist(e)); uint32Data.push_back(uint32Dist(e)); + + boolData1.push_back(boolDist(e)); + boolData2.push_back(boolDist(e)); + boolData3.push_back(boolDist(e)); + boolData4.push_back(boolDist(e)); } std::unordered_map< @@ -164,7 +192,11 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) { inputData{ {"int32Column", int32Data}, {"int64Column", int64Data}, - {"uint32Column", uint32Data}}; + {"uint32Column", uint32Data}, + {"boolColumn1", boolData1}, + {"boolColumn2", boolData2}, + {"boolColumn3", boolData3}, + {"boolColumn4", boolData4}}; auto serializedBytes = serializer0->serializeDataAsBytesForUDP(inputData, batchSize); @@ -189,9 +221,17 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) { auto int32Rst = std::get>(rst.at("int32Column")); auto int64Rst = std::get>(rst.at("int64Column")); auto uint32Rst = std::get>(rst.at("uint32Column")); + auto boolRst1 = std::get>(rst.at("boolColumn1")); + auto boolRst2 = std::get>(rst.at("boolColumn2")); + auto boolRst3 = std::get>(rst.at("boolColumn3")); + auto boolRst4 = std::get>(rst.at("boolColumn4")); testVectorEq(int32Data, int32Rst); testVectorEq(int64Data, int64Rst); testVectorEq(uint32Data, uint32Rst); + testVectorEq(boolData1, boolRst1); + testVectorEq(boolData2, boolRst2); + testVectorEq(boolData3, boolRst3); + testVectorEq(boolData4, boolRst4); } } // namespace fbpcf::mpc_std_lib::unified_data_process::serialization