From f2bae560061ff6bd91cb64906ef6a7623d8747e4 Mon Sep 17 00:00:00 2001 From: Ruiyu Zhu Date: Thu, 9 Mar 2023 12:45:12 -0800 Subject: [PATCH] fix an unintended variable reuse in UdpEncryption object (#502) Summary: Pull Request resolved: https://github.com/facebookresearch/fbpcf/pull/502 `indexOffset_` is meant to record how much data has been processed, and the progress for my data and for the peer's data must be tracked separately. However, the original implementation mistakenly used the same variable for both, so preparing or processing one side's data also reset or advanced the other side's offset, and the unit test failed to catch that. This diff fixes the bug by splitting the variable into `myDataIndexOffset_` and `peerDataIndexOffset_`, and adds the necessary tests: each party now processes both its own data and its peer's data on the same UdpEncryption object. Reviewed By: haochenuw Differential Revision: D43746711 fbshipit-source-id: 94abdfdf4bd6890c32c7ce5019889a14e03fe3ab --- .../data_processor/UdpEncryption.cpp | 14 +- .../data_processor/UdpEncryption.h | 3 +- .../data_processor/test/DataProcessorTest.cpp | 127 ++++++++++++------ 3 files changed, 97 insertions(+), 47 deletions(-) diff --git a/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.cpp b/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.cpp index 791e85f5..78cefe09 100644 --- a/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.cpp +++ b/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.cpp @@ -28,7 +28,7 @@ void UdpEncryption::prepareToProcessMyData(size_t myDataWidth) { statusOfProcessingMyData_ = Status::inProgress; myDataWidth_ = myDataWidth; prgKey_ = fbpcf::engine::util::getRandomM128iFromSystemNoise(); - indexOffset_ = 0; + myDataIndexOffset_ = 0; } void UdpEncryption::processMyData( @@ -45,12 +45,12 @@ void UdpEncryption::processMyData( " but get " + std::to_string(plaintextData.at(0).size())); } auto [ciphertext, nonce] = - UdpUtil::localEncryption(plaintextData, prgKey_, indexOffset_); + UdpUtil::localEncryption(plaintextData, prgKey_, myDataIndexOffset_); agent_->send(nonce); for (size_t i = 0; i < ciphertext.size(); i++) { agent_->send(ciphertext.at(i)); } - indexOffset_ += plaintextData.size(); + 
myDataIndexOffset_ += plaintextData.size(); } void UdpEncryption::prepareToProcessPeerData( @@ -67,7 +67,7 @@ void UdpEncryption::prepareToProcessPeerData( } peerDataWidth_ = peerDataWidth; - indexOffset_ = 0; + peerDataIndexOffset_ = 0; cherryPickedEncryption_ = std::vector>(indexes.size()); @@ -87,17 +87,17 @@ void UdpEncryption::processPeerData(size_t dataSize) { for (size_t i = 0; i < dataSize; i++) { auto ciphertext = agent_->receive(peerDataWidth_); - auto pos = indexToOrderMap_.find(i + indexOffset_); + auto pos = indexToOrderMap_.find(i + peerDataIndexOffset_); if (pos != indexToOrderMap_.end()) { // this ciphertext should be picked up cherryPickedEncryption_.at(pos->second) = std::move(ciphertext); cherryPickedNonce_.at(pos->second) = nonce; - cherryPickedIndex_.at(pos->second) = i + indexOffset_; + cherryPickedIndex_.at(pos->second) = i + peerDataIndexOffset_; indexToOrderMap_.erase(pos); // TODO: this can be further optimized by not copying duplicated nonce. } } - indexOffset_ += dataSize; + peerDataIndexOffset_ += dataSize; } } // namespace fbpcf::mpc_std_lib::unified_data_process::data_processor diff --git a/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.h b/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.h index c9130e1c..5094e338 100644 --- a/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.h +++ b/fbpcf/mpc_std_lib/unified_data_process/data_processor/UdpEncryption.h @@ -80,7 +80,8 @@ class UdpEncryption { std::unique_ptr agent_; - uint64_t indexOffset_; + uint64_t myDataIndexOffset_; + uint64_t peerDataIndexOffset_; size_t myDataWidth_; __m128i prgKey_; diff --git a/fbpcf/mpc_std_lib/unified_data_process/data_processor/test/DataProcessorTest.cpp b/fbpcf/mpc_std_lib/unified_data_process/data_processor/test/DataProcessorTest.cpp index d377c225..49ae3b5f 100644 --- a/fbpcf/mpc_std_lib/unified_data_process/data_processor/test/DataProcessorTest.cpp +++ 
b/fbpcf/mpc_std_lib/unified_data_process/data_processor/test/DataProcessorTest.cpp @@ -140,6 +140,23 @@ std::tuple, std::vector> split( std::vector(src.begin() + cutPosition, src.end())}; } +std::vector> convertToVecs( + size_t rowCount, + size_t dataWidth, + const std::vector>& src) { + std::vector> rst( + rowCount, std::vector(dataWidth)); + + for (size_t i = 0; i < dataWidth; i++) { + for (uint8_t j = 0; j < 8; j++) { + for (size_t k = 0; k < rowCount; k++) { + rst.at(k).at(i) += (src.at(i * 8 + j).at(k) << j); + } + } + } + return rst; +} + void testUdpEncryptionAndDecryptionObjects( std::unique_ptr agent0, std::unique_ptr agent1) { @@ -162,21 +179,28 @@ void testUdpEncryptionAndDecryptionObjects( auto udpDec10 = std::make_unique>(0, 1); auto udpDec11 = std::make_unique>(1, 0); - auto task0 = [numberOfInputShards]( - std::unique_ptr udpEnc, - std::unique_ptr> udpDec0, - std::unique_ptr> udpDec1, - const std::vector>>& - plaintextDataInShards, - size_t dataWidth, - size_t outputSize) { + auto task0 = [](std::unique_ptr udpEnc, + std::unique_ptr> udpDec0, + std::unique_ptr> udpDec1, + const std::vector>>& + plaintextDataInShards, + size_t dataWidth, + size_t outputSize, + const std::vector& indexes, + const std::vector& sizes) { udpEnc->prepareToProcessMyData(dataWidth); - for (size_t i = 0; i < numberOfInputShards; i++) { + for (size_t i = 0; i < plaintextDataInShards.size(); i++) { udpEnc->processMyData(plaintextDataInShards.at(i)); }; + udpEnc->prepareToProcessPeerData(dataWidth, indexes); + for (size_t i = 0; i < sizes.size(); i++) { + udpEnc->processPeerData(sizes.at(i)); + } + size_t outputShard0Size = outputSize / 2; size_t outputShard1Size = outputSize - outputShard0Size; auto key = udpEnc->getExpandedKey(); + auto result0 = udpDec0->decryptMyData(key, dataWidth, outputShard0Size) .openToParty(0) .getValue(); @@ -184,47 +208,65 @@ void testUdpEncryptionAndDecryptionObjects( .openToParty(0) .getValue(); - std::vector> rst0( - outputShard0Size, 
std::vector(dataWidth)); - for (size_t i = 0; i < dataWidth; i++) { - for (uint8_t j = 0; j < 8; j++) { - for (size_t k = 0; k < outputShard0Size; k++) { - rst0.at(k).at(i) += (result0.at(i * 8 + j).at(k) << j); - } - } - } - - std::vector> rst1( - outputShard1Size, std::vector(dataWidth)); - for (size_t i = 0; i < dataWidth; i++) { - for (uint8_t j = 0; j < 8; j++) { - for (size_t k = 0; k < outputShard1Size; k++) { - rst1.at(k).at(i) += (result1.at(i * 8 + j).at(k) << j); - } - } - } + auto rst0 = convertToVecs(outputShard0Size, dataWidth, result0); + auto rst1 = convertToVecs(outputShard1Size, dataWidth, result1); rst0.insert(rst0.end(), rst1.begin(), rst1.end()); - return rst0; + + auto [intersection, nonces, pickedIndexes] = udpEnc->getProcessedData(); + + auto [intersection0, intersection1] = split(intersection, outputShard0Size); + auto [nonces0, nonces1] = split(nonces, outputShard0Size); + auto [indexes0, indexes1] = split(pickedIndexes, outputShard0Size); + + auto rst2 = convertToVecs( + outputShard0Size, + dataWidth, + udpDec0->decryptPeerData(intersection0, nonces0, indexes0) + .openToParty(0) + .getValue()); + auto rst3 = convertToVecs( + outputShard1Size, + dataWidth, + udpDec1->decryptPeerData(intersection1, nonces1, indexes1) + .openToParty(0) + .getValue()); + rst2.insert(rst2.end(), rst3.begin(), rst3.end()); + return std::make_tuple(rst0, rst2); }; - auto task1 = [numberOfInputShards, &dataWidth]( - std::unique_ptr udpEnc, - std::unique_ptr> udpDec0, - std::unique_ptr> udpDec1, - const std::vector& indexes, - const std::vector& sizes) { + auto task1 = [](std::unique_ptr udpEnc, + std::unique_ptr> udpDec0, + std::unique_ptr> udpDec1, + const std::vector& indexes, + const std::vector& sizes, + const std::vector>>& + plaintextDataInShards, + size_t dataWidth) { udpEnc->prepareToProcessPeerData(dataWidth, indexes); - for (size_t i = 0; i < numberOfInputShards; i++) { + for (size_t i = 0; i < sizes.size(); i++) { udpEnc->processPeerData(sizes.at(i)); 
} + udpEnc->prepareToProcessMyData(dataWidth); + for (size_t i = 0; i < plaintextDataInShards.size(); i++) { + udpEnc->processMyData(plaintextDataInShards.at(i)); + }; + auto [intersection, nonces, pickedIndexes] = udpEnc->getProcessedData(); + size_t outputShard0Size = intersection.size() / 2; + size_t outputShard1Size = intersection.size() - outputShard0Size; + auto [intersection0, intersection1] = split(intersection, outputShard0Size); auto [nonces0, nonces1] = split(nonces, outputShard0Size); auto [indexes0, indexes1] = split(pickedIndexes, outputShard0Size); udpDec0->decryptPeerData(intersection0, nonces0, indexes0).openToParty(0); udpDec1->decryptPeerData(intersection1, nonces1, indexes1).openToParty(0); + + auto key = udpEnc->getExpandedKey(); + + udpDec0->decryptMyData(key, dataWidth, outputShard0Size).openToParty(0); + udpDec1->decryptMyData(key, dataWidth, outputShard1Size).openToParty(0); }; auto future1 = std::async( @@ -233,18 +275,25 @@ void testUdpEncryptionAndDecryptionObjects( std::move(udpDec01), std::move(udpDec11), indexes, - sizes); + sizes, + shard, + dataWidth); auto rst = task0( std::move(udpEnc0), std::move(udpDec00), std::move(udpDec10), shard, dataWidth, - outputSize); + outputSize, + indexes, + sizes); future1.get(); for (size_t i = 0; i < outputSize; i++) { - fbpcf::testVectorEq(rst.at(i), expectedOutput.at(i)); + fbpcf::testVectorEq(std::get<0>(rst).at(i), expectedOutput.at(i)); + } + for (size_t i = 0; i < outputSize; i++) { + fbpcf::testVectorEq(std::get<1>(rst).at(i), expectedOutput.at(i)); } }