From 4f8705f5800c46c1cf09f2d17625fa423f427a55 Mon Sep 17 00:00:00 2001 From: Tal Davidi Date: Thu, 23 Feb 2023 10:18:06 -0800 Subject: [PATCH] Add RowDefinition implementation with integer column support (#490) Summary: Pull Request resolved: https://github.com/facebookresearch/fbpcf/pull/490 # Background: Currently in order to successfully use UDP, you must write some carefully crafted code that will take all the rows of metadata for one side and package it into a collection of bytes. Afterwards the caller will get a `SecString` object back which is a bit representation of all the bytes they passed in, minus the filtered out rows. The user must then extract the corresponding bits for each column into separate MPC Types. This is a cumbersome process which is error prone, as you must make sure to carefully match up the two steps and any changes can cause a bug. # This Diff Creates the implementation for the RowDefinition. As a start we only support integer columns, more types to come in later diff. Reviewed By: RuiyuZhu Differential Revision: D43208067 fbshipit-source-id: 94a1141233a0a613d3ded6a8975f0ddec0d6da17 --- .../serialization/RowStructureDefinition.h | 223 ++++++++++++++++++ .../test/RowSerializationTest.cpp | 197 ++++++++++++++++ 2 files changed, 420 insertions(+) create mode 100644 fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h create mode 100644 fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h new file mode 100644 index 00000000..00b99597 --- /dev/null +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/FixedSizeArrayColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/PackedBitFieldColumn.h" + +#include "folly/Format.h" + +namespace fbpcf::mpc_std_lib::unified_data_process::serialization { + +template +class RowStructureDefinition : public IRowStructureDefinition { + public: + using SecString = frontend::BitString; + using SecBit = frontend::Bit; + + explicit RowStructureDefinition( + std::unique_ptr< + std::vector>>> + columnDefinitions) + : columnDefinitions_(std::move(columnDefinitions)) {} + + /* Returns the number of bytes to serialize a single row */ + size_t getRowSizeBytes() const override { + size_t rst = 0; + for (const auto& columnType : *columnDefinitions_.get()) { + rst += columnType->getColumnSizeBytes(); + } + + return rst; + } + + std::vector> serializeDataAsBytesForUDP( + const std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType>& + data, + int numRows) const override { + // validate number of columns matches what is expected + size_t expectedColumns = 0; + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + const PackedBitFieldColumn* packedBitCol = + dynamic_cast*>( + columnDefinition.get()); + + if (packedBitCol == nullptr) { + expectedColumns++; + } else { + expectedColumns += packedBitCol->getSubColumnNames().size(); + } + } + if (data.size() != expectedColumns) { + throw std::runtime_error( + "Mismatch between number of columns defined by row structure and what was passed in."); + } + + size_t byteOffset = 0; + + std::vector> writeBuffers( + numRows, std::vector(getRowSizeBytes())); + + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + const IColumnDefinition* columnPointer = + columnDefinition.get(); + auto columnType = columnDefinition->getColumnType(); + + switch (columnType) { + case IColumnDefinition::SupportedColumnTypes::UInt32: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + case IColumnDefinition::SupportedColumnTypes::Int32: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + case IColumnDefinition::SupportedColumnTypes::Int64: + serializeIntegerColumn( + columnPointer, data, writeBuffers, numRows, byteOffset); + break; + default: + throw std::runtime_error( + "Unknown column type while serializing data."); + } + + byteOffset += columnPointer->getColumnSizeBytes(); + } + + return writeBuffers; + } + + // Following a run of the UDP protocol, deserialize the array into pointers + // to MPC types. Data is represented in column order format + virtual std::unordered_map< + std::string, + typename IColumnDefinition::DeserializeType> + deserializeUDPOutputIntoMPCTypes( + const SecString& secretSharedData) const override { + std::vector> secretSharedBits = + secretSharedData.extractStringShare().getValue(); + secretSharedBits = transpose(secretSharedBits); + std::vector> secretSharedBytes(0); + secretSharedBytes.reserve(secretSharedBits.size()); + for (int i = 0; i < secretSharedBits.size(); i++) { + secretSharedBytes.push_back(convertFromBits(secretSharedBits[i])); + } + + std::unordered_map< + std::string, + typename IColumnDefinition::DeserializeType> + rst; + size_t byteOffset = 0; + for (const std::unique_ptr>& + columnDefinition : *columnDefinitions_.get()) { + rst.emplace( + columnDefinition->getColumnName(), + columnDefinition->deserializeSharesToMPCType( + secretSharedBytes, byteOffset)); + byteOffset += columnDefinition->getColumnSizeBytes(); + } + + return rst; + } + + private: + // use an ordered map for consistency between both parties + std::unique_ptr>>> + columnDefinitions_; + + std::vector convertFromBits( + const std::vector& data) const { + std::vector rst; + rst.reserve(data.size() / 8); + + size_t i = 0; + + while (i < data.size()) { + unsigned char val = 0; + size_t bitsLeft = data.size() - i > 8 ? 8 : data.size() - i; + for (auto j = 0; j < bitsLeft; j++) { + val |= (data[i] << j); + ++i; + } + rst.push_back(val); + } + + return rst; + } + + template + std::vector> transpose( + const std::vector>& data) const { + std::vector> result; + if (data.size() == 0) { + return result; + } + + result.reserve(data[0].size()); + for (size_t column = 0; column < data[0].size(); column++) { + std::vector innerArray(data.size()); + result.push_back(std::vector(data.size())); + for (size_t row = 0; row < data.size(); row++) { + result[column][row] = data[row][column]; + } + } + return result; + } + + template + void serializeIntegerColumn( + const IColumnDefinition* columnPointer, + const std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType>& + inputData, + std::vector>& writeBuffers, + int numRows, + size_t byteOffset) const { + std::string colName = columnPointer->getColumnName(); + + if (!inputData.contains(colName)) { + throw std::runtime_error(folly::sformat( + "Column {} which was defined in the structure was not included" + " in the input data map.", + colName)); + } + + using IntType = + typename IntegerColumn::NativeType; + + const std::vector intVals = + std::get>(inputData.at(colName)); + + if (intVals.size() != numRows) { + std::string err = folly::sformat( + "Invalid number of values for column {}. Got {} values but number of rows should be {} ", + colName, + intVals.size(), + numRows); + throw std::runtime_error(err); + } + + for (int i = 0; i < numRows; i++) { + columnPointer->serializeColumnAsPlaintextBytes( + &intVals[i], writeBuffers[i].data() + byteOffset); + } + } +}; + +} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization diff --git a/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp new file mode 100644 index 00000000..4c69f831 --- /dev/null +++ b/fbpcf/mpc_std_lib/unified_data_process/serialization/test/RowSerializationTest.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "fbpcf/engine/communication/test/AgentFactoryCreationHelper.h" +#include "fbpcf/frontend/MPCTypes.h" +#include "fbpcf/scheduler/ISchedulerFactory.h" +#include "fbpcf/scheduler/SchedulerHelper.h" +#include "fbpcf/test/TestHelper.h" + +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h" +#include "fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h" + +namespace fbpcf::mpc_std_lib::unified_data_process::serialization { + +template +std::unique_ptr> createRowDefinition() { + auto columnDefs = std::make_unique< + std::vector>>>(0); + + columnDefs->push_back( + std::make_unique>("int32Column")); + columnDefs->push_back( + std::make_unique>("int64Column")); + columnDefs->push_back( + std::make_unique>("uint32Column")); + auto serializer = std::make_unique>( + std::move(columnDefs)); + return std::move(serializer); +} + +template +std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType> +deserializeAndRevealAllColumns( + fbpcf::scheduler::ISchedulerFactory& schedulerFactory, + const std::vector>& serializedSecretShares, + const std::unique_ptr>& + rowDefinition) { + auto scheduler = schedulerFactory.create(); + + fbpcf::scheduler::SchedulerKeeper::setScheduler( + std::move(scheduler)); + + // The vector of vector of bytes would be passed into UDP, and the SecString + // output will have less number of rows based on the intersection. We are + // pretending there is no data filtered out and party 0 creates the SecString + // as a private input. + std::vector> bitSharesTranspose( + serializedSecretShares[0].size() * 8, + std::vector(serializedSecretShares.size())); + + for (int i = 0; i < serializedSecretShares.size(); i++) { + for (int j = 0; j < serializedSecretShares[i].size(); j++) { + for (int k = 0; k < 8; k++) { + bitSharesTranspose[j * 8 + k][i] = + serializedSecretShares[i][j] >> k & 1; + } + } + } + + frontend::BitString udpOutput(bitSharesTranspose, 0); + + auto deserialization = + rowDefinition.get()->deserializeUDPOutputIntoMPCTypes(udpOutput); + + std::unordered_map< + std::string, + typename IRowStructureDefinition::InputColumnDataType> + rst; + + std::vector int32Opened = + std::get::Sec32Int>( + deserialization.at("int32Column")) + .openToParty(0) + .getValue(); + std::vector int32Data(int32Opened.size()); + std::transform( + int32Opened.begin(), + int32Opened.end(), + int32Data.begin(), + [](int64_t data) { return data; }); + rst.emplace("int32Column", int32Data); + + std::vector int64Data = + std::get::Sec64Int>( + deserialization.at("int64Column")) + .openToParty(0) + .getValue(); + + rst.emplace("int64Column", int64Data); + + std::vector uint32Opened = + std::get::SecUnsigned32Int>( + deserialization.at("uint32Column")) + .openToParty(0) + .getValue(); + + std::vector uint32Data(uint32Opened.size()); + std::transform( + uint32Opened.begin(), + uint32Opened.end(), + uint32Data.begin(), + [](uint64_t data) { return data; }); + rst.emplace("uint32Column", uint32Data); + + return rst; +} + +TEST(RowSerializationTest, RowWithMultipleColumnsTest) { + auto factories = fbpcf::engine::communication::getInMemoryAgentFactory(2); + + auto schedulerFactory0 = + fbpcf::scheduler::NetworkPlaintextSchedulerFactory( + 0, *factories[0]); + + auto schedulerFactory1 = + fbpcf::scheduler::NetworkPlaintextSchedulerFactory( + 1, *factories[1]); + + const size_t batchSize = 100; + + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution uint32Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + std::uniform_int_distribution int32Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + std::uniform_int_distribution int64Dist( + std::numeric_limits().min(), + std::numeric_limits().max()); + + std::unique_ptr> serializer0 = + createRowDefinition<0>(); + std::unique_ptr> serializer1 = + createRowDefinition<1>(); + + EXPECT_EQ(serializer0->getRowSizeBytes(), 16); + EXPECT_EQ(serializer1->getRowSizeBytes(), 16); + + std::vector int32Data(0); + std::vector int64Data(0); + std::vector uint32Data(0); + + for (int i = 0; i < batchSize; i++) { + int32Data.push_back(int32Dist(e)); + int64Data.push_back(int64Dist(e)); + uint32Data.push_back(uint32Dist(e)); + } + + std::unordered_map< + std::string, + IRowStructureDefinition<0>::InputColumnDataType> + inputData{ + {"int32Column", int32Data}, + {"int64Column", int64Data}, + {"uint32Column", uint32Data}}; + + auto serializedBytes = + serializer0->serializeDataAsBytesForUDP(inputData, batchSize); + + auto future0 = + std::async([&schedulerFactory0, &serializedBytes, &serializer0]() { + return deserializeAndRevealAllColumns<0>( + schedulerFactory0, serializedBytes, serializer0); + }); + + auto future1 = std::async([&schedulerFactory1, &serializer1]() { + return deserializeAndRevealAllColumns<1>( + schedulerFactory1, + std::vector>( + batchSize, std::vector(serializer1->getRowSizeBytes())), + serializer1); + }); + + auto rst = future0.get(); + future1.get(); + + auto int32Rst = std::get>(rst.at("int32Column")); + auto int64Rst = std::get>(rst.at("int64Column")); + auto uint32Rst = std::get>(rst.at("uint32Column")); + + testVectorEq(int32Data, int32Rst); + testVectorEq(int64Data, int64Rst); + testVectorEq(uint32Data, uint32Rst); +} +} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization