diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h new file mode 100644 index 00000000..00b99597 --- /dev/null +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/FixedSizeArrayColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/PackedBitFieldColumn.h" + +#include "folly/Format.h" + +namespace fbpcf::mpc_std_lib::unified_data_process::serialization { + +template +class RowStructureDefinition : public IRowStructureDefinition { + public: + using SecString = frontend::BitString; + using SecBit = frontend::Bit; + + explicit RowStructureDefinition( + std::unique_ptr< + std::vector>>> + columnDefinitions) + : columnDefinitions_(std::move(columnDefinitions)) {} + + /* Returns the number of bytes to serialize a single row */ + size_t getRowSizeBytes() const override { + size_t rst = 0; + for (const auto& columnType : *columnDefinitions_.get()) { + rst += columnType->getColumnSizeBytes(); + } + + return rst; + } + + std::vector> serializeDataAsBytesForUDP( + const std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType>& + data, + int numRows) const override { + // validate number of columns matches what is expected + size_t expectedColumns = 0; + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + const PackedBitFieldColumn* packedBitCol = + 
dynamic_cast*>( + columnDefinition.get()); + + if (packedBitCol == nullptr) { + expectedColumns++; + } else { + expectedColumns += packedBitCol->getSubColumnNames().size(); + } + } + if (data.size() != expectedColumns) { + throw std::runtime_error( + "Mismatch between number of columns defined by row structure and what was passed in."); + } + + size_t byteOffset = 0; + + std::vector> writeBuffers( + numRows, std::vector(getRowSizeBytes())); + + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + const IColumnDefinition* columnPointer = + columnDefinition.get(); + auto columnType = columnDefinition->getColumnType(); + + switch (columnType) { + case IColumnDefinition::SupportedColumnTypes::UInt32: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + case IColumnDefinition::SupportedColumnTypes::Int32: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + case IColumnDefinition::SupportedColumnTypes::Int64: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + default: + throw std::runtime_error( + "Unknown column type while serializing data."); + } + + byteOffset += columnPointer->getColumnSizeBytes(); + } + + return writeBuffers; + } + + // Following a run of the UDP protocol, deserialize the array into pointers + // to MPC types. 
Data is represented in column order format + virtual std::unordered_map< + std::string, + typename IColumnDefinition::DeserializeType> + deserializeUDPOutputIntoMPCTypes( + const SecString& secretSharedData) const override { + std::vector> secretSharedBits = + secretSharedData.extractStringShare().getValue(); + secretSharedBits = transpose(secretSharedBits); + std::vector> secretSharedBytes(0); + secretSharedBytes.reserve(secretSharedBits.size()); + for (int i = 0; i < secretSharedBits.size(); i++) { + secretSharedBytes.push_back(convertFromBits(secretSharedBits[i])); + } + + std::unordered_map< + std::string, + typename IColumnDefinition::DeserializeType> + rst; + size_t byteOffset = 0; + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + rst.emplace( + columnDefinition->getColumnName(), + columnDefinition->deserializeSharesToMPCType( + secretSharedBytes, byteOffset)); + byteOffset += columnDefinition->getColumnSizeBytes(); + } + + return rst; + } + + private: + // use an ordered vector (not a map) for consistency between both parties + std::unique_ptr>>> + columnDefinitions_; + + std::vector convertFromBits( + const std::vector& data) const { + std::vector rst; + rst.reserve(data.size() / 8); + + size_t i = 0; + + while (i < data.size()) { + unsigned char val = 0; + size_t bitsLeft = data.size() - i > 8 ? 
8 : data.size() - i; + for (auto j = 0; j < bitsLeft; j++) { + val |= (data[i] << j); + ++i; + } + rst.push_back(val); + } + + return rst; + } + + template + std::vector> transpose( + const std::vector>& data) const { + std::vector> result; + if (data.size() == 0) { + return result; + } + + result.reserve(data[0].size()); + for (size_t column = 0; column < data[0].size(); column++) { + std::vector innerArray(data.size()); + result.push_back(std::vector(data.size())); + for (size_t row = 0; row < data.size(); row++) { + result[column][row] = data[row][column]; + } + } + return result; + } + + template + void serializeIntegerColumn( + const IColumnDefinition* columnPointer, + const std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType>& + inputData, + std::vector>& writeBuffers, + int numRows, + size_t byteOffset) const { + std::string colName = columnPointer->getColumnName(); + + if (!inputData.contains(colName)) { + throw std::runtime_error(folly::sformat( + "Column {} which was defined in the structure was not included" + " in the input data map.", + colName)); + } + + using IntType = + typename IntegerColumn::NativeType; + + const std::vector intVals = + std::get>(inputData.at(colName)); + + if (intVals.size() != numRows) { + std::string err = folly::sformat( + "Invalid number of values for column {}. 
Got {} values but number of rows should be {} ", + colName, + intVals.size(), + numRows); + throw std::runtime_error(err); + } + + for (int i = 0; i < numRows; i++) { + columnPointer->serializeColumnAsPlaintextBytes( + &intVals[i], writeBuffers[i].data() + byteOffset); + } + } +}; + +} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp new file mode 100644 index 00000000..4c69f831 --- /dev/null +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "fbpcf/engine/communication/test/AgentFactoryCreationHelper.h" +#include "fbpcf/frontend/MPCTypes.h" +#include "fbpcf/scheduler/ISchedulerFactory.h" +#include "fbpcf/scheduler/SchedulerHelper.h" +#include "fbpcf/test/TestHelper.h" + +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h" + +namespace fbpcf::mpc_std_lib::unified_data_process::serialization { + +template +std::unique_ptr> createRowDefinition() { + auto columnDefs = std::make_unique< + std::vector>>>(0); + + columnDefs->push_back( + std::make_unique>("int32Column")); + columnDefs->push_back( + std::make_unique>("int64Column")); + columnDefs->push_back( + std::make_unique>("uint32Column")); + auto serializer = std::make_unique>( + std::move(columnDefs)); + return std::move(serializer); +} + +template +std::unordered_map< + std::string, + typename 
IRowStructureDefinition::InputColumnDataType> +deserializeAndRevealAllColumns( + fbpcf::scheduler::ISchedulerFactory& schedulerFactory, + const std::vector>& serializedSecretShares, + const std::unique_ptr>& + rowDefinition) { + auto scheduler = schedulerFactory.create(); + + fbpcf::scheduler::SchedulerKeeper::setScheduler( + std::move(scheduler)); + + // The vector of byte vectors would be passed into UDP, and the SecString + // output will have fewer rows based on the intersection. We are + // pretending there is no data filtered out and party 0 creates the SecString + // as a private input. + std::vector> bitSharesTranspose( + serializedSecretShares[0].size() * 8, + std::vector(serializedSecretShares.size())); + + for (int i = 0; i < serializedSecretShares.size(); i++) { + for (int j = 0; j < serializedSecretShares[i].size(); j++) { + for (int k = 0; k < 8; k++) { + bitSharesTranspose[j * 8 + k][i] = + serializedSecretShares[i][j] >> k & 1; + } + } + } + + frontend::BitString udpOutput(bitSharesTranspose, 0); + + auto deserialization = + rowDefinition.get()->deserializeUDPOutputIntoMPCTypes(udpOutput); + + std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType> + rst; + + std::vector int32Opened = + std::get::Sec32Int>( + deserialization.at("int32Column")) + .openToParty(0) + .getValue(); + std::vector int32Data(int32Opened.size()); + std::transform( + int32Opened.begin(), + int32Opened.end(), + int32Data.begin(), + [](int64_t data) { return data; }); + rst.emplace("int32Column", int32Data); + + std::vector int64Data = + std::get::Sec64Int>( + deserialization.at("int64Column")) + .openToParty(0) + .getValue(); + + rst.emplace("int64Column", int64Data); + + std::vector uint32Opened = + std::get::SecUnsigned32Int>( + deserialization.at("uint32Column")) + .openToParty(0) + .getValue(); + + std::vector uint32Data(uint32Opened.size()); + std::transform( + uint32Opened.begin(), + uint32Opened.end(), + uint32Data.begin(), 
+ [](uint64_t data) { return data; }); + rst.emplace("uint32Column", uint32Data); + + return rst; +} + +TEST(RowSerializationTest, RowWithMultipleColumnsTest) { + auto factories = fbpcf::engine::communication::getInMemoryAgentFactory(2); + + auto schedulerFactory0 = + fbpcf::scheduler::NetworkPlaintextSchedulerFactory( + 0, *factories[0]); + + auto schedulerFactory1 = + fbpcf::scheduler::NetworkPlaintextSchedulerFactory( + 1, *factories[1]); + + const size_t batchSize = 100; + + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution uint32Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + std::uniform_int_distribution int32Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + std::uniform_int_distribution int64Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + + std::unique_ptr> serializer0 = + createRowDefinition<0>(); + std::unique_ptr> serializer1 = + createRowDefinition<1>(); + + EXPECT_EQ(serializer0->getRowSizeBytes(), 16); + EXPECT_EQ(serializer1->getRowSizeBytes(), 16); + + std::vector int32Data(0); + std::vector int64Data(0); + std::vector uint32Data(0); + + for (int i = 0; i < batchSize; i++) { + int32Data.push_back(int32Dist(e)); + int64Data.push_back(int64Dist(e)); + uint32Data.push_back(uint32Dist(e)); + } + + std::unordered_map< + std::string, + IRowStructureDefinition<0>::InputColumnDataType> + inputData{ + {"int32Column", int32Data}, + {"int64Column", int64Data}, + {"uint32Column", uint32Data}}; + + auto serializedBytes = + serializer0->serializeDataAsBytesForUDP(inputData, batchSize); + + auto future0 = + std::async([&schedulerFactory0, &serializedBytes, &serializer0]() { + return deserializeAndRevealAllColumns<0>( + schedulerFactory0, serializedBytes, serializer0); + }); + + auto future1 = std::async([&schedulerFactory1, &serializer1]() { + return deserializeAndRevealAllColumns<1>( + schedulerFactory1, + std::vector>( + batchSize, 
std::vector(serializer1->getRowSizeBytes())), + serializer1); + }); + + auto rst = future0.get(); + future1.get(); + + auto int32Rst = std::get>(rst.at("int32Column")); + auto int64Rst = std::get>(rst.at("int64Column")); + auto uint32Rst = std::get>(rst.at("uint32Column")); + + testVectorEq(int32Data, int32Rst); + testVectorEq(int64Data, int64Rst); + testVectorEq(uint32Data, uint32Rst); +} +} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization