diff --git a/src/protocol/sort/apply_sort/shuffle.rs b/src/protocol/sort/apply_sort/shuffle.rs index 91919c44f..289b28dec 100644 --- a/src/protocol/sort/apply_sort/shuffle.rs +++ b/src/protocol/sort/apply_sort/shuffle.rs @@ -3,18 +3,14 @@ use ipa_macros::Step; use crate::{ error::Error, - helpers::Direction, protocol::{ - basics::{ - apply_permutation::{apply, apply_inv}, - Reshare, - }, + basics::Reshare, context::Context, sort::{ - shuffle::{shuffle_for_helper, ShuffleOrUnshuffle}, - ShuffleStep::{self, Shuffle1, Shuffle2, Shuffle3}, + shuffle::{shuffle_or_unshuffle_once, ShuffleOrUnshuffle}, + ShuffleStep::{Shuffle1, Shuffle2, Shuffle3}, }, - NoRecord, RecordId, + RecordId, }, }; @@ -30,40 +26,6 @@ impl From for InnerVectorElementStep { } } -/// `shuffle_once` is called for the helpers -/// i) 2 helpers receive permutation pair and choose the permutation to be applied -/// ii) 2 helpers apply the permutation to their shares -/// iii) reshare to `to_helper` -#[tracing::instrument(name = "shuffle_once", skip_all, fields(to = ?shuffle_for_helper(which_step)))] -async fn shuffle_once( - mut input: Vec, - random_permutations: (&[u32], &[u32]), - shuffle_or_unshuffle: ShuffleOrUnshuffle, - ctx: &C, - which_step: ShuffleStep, -) -> Result, Error> -where - C: Context, - I: Reshare + Send + Sync, -{ - let to_helper = shuffle_for_helper(which_step); - let ctx = ctx.narrow(&which_step); - - if to_helper != ctx.role() { - let permutation_to_apply = if to_helper.peer(Direction::Left) == ctx.role() { - random_permutations.0 - } else { - random_permutations.1 - }; - - match shuffle_or_unshuffle { - ShuffleOrUnshuffle::Shuffle => apply_inv(permutation_to_apply, &mut input), - ShuffleOrUnshuffle::Unshuffle => apply(permutation_to_apply, &mut input), - } - } - input.reshare(ctx, NoRecord, to_helper).await -} - #[embed_doc_image("shuffle", "images/sort/shuffle.png")] /// Shuffle calls `shuffle_once` three times with 2 helpers shuffling the shares each time. /// Order of calling `shuffle_once` is shuffle with (H2, H3), (H3, H1) and (H1, H2). @@ -85,7 +47,7 @@ where C: Context, I: Reshare + Send + Sync, { - let input = shuffle_once( + let input = shuffle_or_unshuffle_once( input, random_permutations, ShuffleOrUnshuffle::Shuffle, @@ -93,7 +55,7 @@ where Shuffle1, ) .await?; - let input = shuffle_once( + let input = shuffle_or_unshuffle_once( input, random_permutations, ShuffleOrUnshuffle::Shuffle, @@ -101,7 +63,7 @@ where Shuffle2, ) .await?; - shuffle_once( + shuffle_or_unshuffle_once( input, random_permutations, ShuffleOrUnshuffle::Shuffle, diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index d500bf0a2..e8ff70396 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -64,7 +64,8 @@ pub(super) fn shuffle_for_helper(which_step: ShuffleStep) -> Role { /// i) 2 helpers receive permutation pair and choose the permutation to be applied /// ii) 2 helpers apply the permutation to their shares /// iii) reshare to `to_helper` -async fn shuffle_or_unshuffle_once( +#[tracing::instrument(name = "shuffle_once", skip_all, fields(to = ?shuffle_for_helper(which_step)))] +pub async fn shuffle_or_unshuffle_once( mut input: Vec, random_permutations: (&[u32], &[u32]), shuffle_or_unshuffle: ShuffleOrUnshuffle, @@ -72,8 +73,6 @@ async fn shuffle_or_unshuffle_once( which_step: ShuffleStep, ) -> Result, Error> where - F: Field, - S: SecretSharing, C: Context, S: Reshare + Send + Sync, {