diff --git a/CHANGELOG.md b/CHANGELOG.md index 418d660..10de312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `ff::Field::sum_of_products` ## [0.12.1] - 2022-10-28 ### Fixed diff --git a/src/lib.rs b/src/lib.rs index 4ccd0de..fa54c91 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,21 @@ pub trait Field: res } + + /// Returns `pairs.into_iter().fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)`. + /// + /// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length + /// sequences of elements `a` and `b`, such that `pairs = a.zip(b)`. + /// + /// The provided implementation of this trait method uses the direct calculation given + /// above. Implementations of `Field` should override this to use more efficient + /// methods that take advantage of their internal representation, such as interleaving + /// or sharing modular reductions. + fn sum_of_products<'a, I: IntoIterator + Clone>(pairs: I) -> Self { + pairs + .into_iter() + .fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i)) + } } /// This represents an element of a prime field. diff --git a/tests/derive.rs b/tests/derive.rs index bfa2cd2..0ae7e3c 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -37,6 +37,47 @@ mod full_limbs { } } +#[test] +fn sum_of_products() { + use ff::{Field, PrimeField}; + + let one = Bls381K12Scalar::one(); + + // [1, 2, 3, 4] + let values: Vec<_> = (0..4) + .scan(one, |acc, _| { + let ret = *acc; + *acc += &one; + Some(ret) + }) + .collect(); + + // We'll pair each value with itself. + let expected = Bls381K12Scalar::from_str_vartime("30").unwrap(); + + // Check that we can produce the necessary input from two iterators. + assert_eq!( + // Directly produces (&v, &v) + Bls381K12Scalar::sum_of_products(values.iter().zip(values.iter())), + expected, + ); + + // Check that we can produce the necessary input from an iterator of values. + assert_eq!( + // Maps &v to (&v, &v) + Bls381K12Scalar::sum_of_products(values.iter().map(|v| (v, v))), + expected, + ); + + // Check that we can produce the necessary input from an iterator of tuples. + let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect(); + assert_eq!( + // Maps &(a, b) to (&a, &b) + Bls381K12Scalar::sum_of_products(tuples.iter().map(|(a, b)| (a, b))), + expected, + ); +} + #[test] fn batch_inversion() { use ff::{BatchInverter, Field};