Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Field::sum_of_products method #80

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `ff::Field::sum_of_products`

## [0.12.1] - 2022-10-28
### Fixed
Expand Down
15 changes: 15 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ pub trait Field:

res
}

/// Returns `pairs.into_iter().fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)`.
///
/// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length
/// sequences of elements `a` and `b`, such that `pairs = a.zip(b)`.
///
/// The provided implementation of this trait method uses the direct calculation given
/// above. Implementations of `Field` should override this to use more efficient
/// methods that take advantage of their internal representation, such as interleaving
/// or sharing modular reductions.
fn sum_of_products<'a, I: IntoIterator<Item = (&'a Self, &'a Self)> + Clone>(pairs: I) -> Self {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally made this an iterator of &(a, b), and I tested this against the array impl and saw little-to-no performance difference. I then changed it to an iterator of (&a, &b) because I thought it would be more flexible, but doing that significantly hurts performance (I presume because we're calling .clone() on an owned tuple that contains references, rather than on a reference).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, even using nicer tricks for constructing the input arrays, having this API take (&a, &b) causes bls12_381 pairings to be over 5% slower compared to arrays or &(a, b).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried two iterators instead:

fn sum_of_products<'a>(
        a: impl IntoIterator<Item = &'a Self> + Clone,
        b: impl IntoIterator<Item = &'a Self> + Clone,
    ) -> Self

and then a.clone().into_iter().zip(b.clone().into_iter()); this is similarly slower than arrays.

pairs
.into_iter()
.fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i))
}
}

/// This represents an element of a prime field.
Expand Down
41 changes: 41 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,47 @@ mod full_limbs {
}
}

#[test]
fn sum_of_products() {
use ff::{Field, PrimeField};

let one = Bls381K12Scalar::one();

// [1, 2, 3, 4]
let values: Vec<_> = (0..4)
.scan(one, |acc, _| {
let ret = *acc;
*acc += &one;
Some(ret)
})
.collect();

// We'll pair each value with itself.
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();

// Check that we can produce the necessary input from two iterators.
assert_eq!(
// Directly produces (&v, &v)
Bls381K12Scalar::sum_of_products(values.iter().zip(values.iter())),
expected,
);

// Check that we can produce the necessary input from an iterator of values.
assert_eq!(
// Maps &v to (&v, &v)
Bls381K12Scalar::sum_of_products(values.iter().map(|v| (v, v))),
expected,
);

// Check that we can produce the necessary input from an iterator of tuples.
let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect();
assert_eq!(
// Maps &(a, b) to (&a, &b)
Bls381K12Scalar::sum_of_products(tuples.iter().map(|(a, b)| (a, b))),
expected,
);
}

#[test]
fn batch_inversion() {
use ff::{BatchInverter, Field};
Expand Down