From 08e98136881fdb7d8f4c9cdf55da904c7226dfed Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 28 Nov 2023 22:14:38 -0800 Subject: [PATCH 1/5] Use `Integer` and `Monty` instead of `Uint` --- Cargo.toml | 2 +- benches/bench.rs | 44 +++++++------- src/hazmat/gcd.rs | 8 +-- src/hazmat/jacobi.rs | 71 ++++++++++------------ src/hazmat/lucas.rs | 98 +++++++++++++++--------------- src/hazmat/miller_rabin.rs | 70 +++++++++++----------- src/hazmat/sieve.rs | 94 ++++++++++++++--------------- src/lib.rs | 3 + src/presets.rs | 87 +++++++++++++-------------- src/traits.rs | 15 ++--- src/uint_traits.rs | 120 +++++++++++++++++++++++++++++++++++++ 11 files changed, 361 insertions(+), 251 deletions(-) create mode 100644 src/uint_traits.rs diff --git a/Cargo.toml b/Cargo.toml index 54b05f6..d1fcea9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ categories = ["cryptography", "no-std"] rust-version = "1.73" [dependencies] -crypto-bigint = { version = "0.6.0-pre.5", default-features = false, features = ["rand_core"] } +crypto-bigint = { version = "0.6.0-pre.5", default-features = false, features = ["alloc", "rand_core"] } rand_core = { version = "0.6.4", default-features = false } openssl = { version = "0.10.39", optional = true, features = ["vendored"] } rug = { version = "1.18", default-features = false, features = ["integer"], optional = true } diff --git a/benches/bench.rs b/benches/bench.rs index 48eb504..edd5637 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -22,8 +22,8 @@ fn make_rng() -> ChaCha8Rng { ChaCha8Rng::from_seed(*b"01234567890123456789012345678901") } -fn make_sieve(rng: &mut impl CryptoRngCore) -> Sieve { - let start = random_odd_uint::(rng, Uint::::BITS); +fn make_sieve(rng: &mut impl CryptoRngCore) -> Sieve> { + let start = random_odd_uint::>(rng, Uint::::BITS); Sieve::new(&start, Uint::::BITS, false) } @@ -36,13 +36,13 @@ fn bench_sieve(c: &mut Criterion) { let mut group = c.benchmark_group("Sieve"); group.bench_function("(U128) random start", |b| { - b.iter(|| random_odd_uint::<{ nlimbs!(128) }>(&mut OsRng, 128)) + b.iter(|| random_odd_uint::>(&mut OsRng, 128)) }); group.bench_function("(U128) creation", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(128) }>(&mut OsRng, 128), - |start| Sieve::new(&start, 128, false), + || random_odd_uint::>(&mut OsRng, 128), + |start| Sieve::new(start.as_ref(), 128, false), BatchSize::SmallInput, ) }); @@ -57,13 +57,13 @@ fn bench_sieve(c: &mut Criterion) { }); group.bench_function("(U1024) random start", |b| { - b.iter(|| random_odd_uint::<{ nlimbs!(1024) }>(&mut OsRng, 1024)) + b.iter(|| random_odd_uint::>(&mut OsRng, 1024)) }); group.bench_function("(U1024) creation", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(1024) }>(&mut OsRng, 1024), - |start| Sieve::new(&start, 1024, false), + || random_odd_uint::>(&mut OsRng, 1024), + |start| Sieve::new(start.as_ref(), 1024, false), BatchSize::SmallInput, ) }); @@ -84,7 +84,7 @@ fn bench_miller_rabin(c: &mut Criterion) { group.bench_function("(U128) creation", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(128) }>(&mut OsRng, 128), + || random_odd_uint::>(&mut OsRng, 128), MillerRabin::new, BatchSize::SmallInput, ) @@ -100,7 +100,7 @@ fn bench_miller_rabin(c: &mut Criterion) { group.bench_function("(U1024) creation", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(1024) }>(&mut OsRng, 1024), + || random_odd_uint::>(&mut OsRng, 1024), MillerRabin::new, BatchSize::SmallInput, ) @@ -193,39 +193,39 @@ fn bench_presets(c: &mut Criterion) { group.bench_function("(U128) Prime test", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(128) }>(&mut OsRng, 128), - |num| is_prime_with_rng(&mut OsRng, &num), + || random_odd_uint::>(&mut OsRng, 128), + |num| is_prime_with_rng(&mut OsRng, num.as_ref()), BatchSize::SmallInput, ) }); group.bench_function("(U128) Safe prime test", |b| { b.iter_batched( - || random_odd_uint::<{ nlimbs!(128) }>(&mut OsRng, 128), - |num| is_safe_prime_with_rng(&mut OsRng, &num), + || random_odd_uint::>(&mut OsRng, 128), + |num| is_safe_prime_with_rng(&mut OsRng, num.as_ref()), BatchSize::SmallInput, ) }); let mut rng = make_rng(); group.bench_function("(U128) Random prime", |b| { - b.iter(|| generate_prime_with_rng::<{ nlimbs!(128) }>(&mut rng, None)) + b.iter(|| generate_prime_with_rng::>(&mut rng, 128)) }); let mut rng = make_rng(); group.bench_function("(U1024) Random prime", |b| { - b.iter(|| generate_prime_with_rng::<{ nlimbs!(1024) }>(&mut rng, None)) + b.iter(|| generate_prime_with_rng::>(&mut rng, 1024)) }); let mut rng = make_rng(); group.bench_function("(U128) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::<{ nlimbs!(128) }>(&mut rng, None)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) }); group.sample_size(20); let mut rng = make_rng(); group.bench_function("(U1024) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::<{ nlimbs!(1024) }>(&mut rng, None)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 1024)) }); group.finish(); @@ -235,19 +235,19 @@ fn bench_presets(c: &mut Criterion) { let mut rng = make_rng(); group.bench_function("(U128) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::<{ nlimbs!(128) }>(&mut rng, None)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) }); // The performance should scale with the prime size, not with the Uint size. // So we should strive for this test's result to be as close as possible // to that of the previous one and as far away as possible from the next one. group.bench_function("(U256) Random 128 bit safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::<{ nlimbs!(256) }>(&mut rng, Some(128))) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) }); // The upper bound for the previous test. group.bench_function("(U256) Random 256 bit safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::<{ nlimbs!(256) }>(&mut rng, None)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 256)) }); group.finish(); @@ -258,7 +258,7 @@ fn bench_gmp(c: &mut Criterion) { let mut group = c.benchmark_group("GMP"); fn random(rng: &mut impl CryptoRngCore) -> Integer { - let num = random_odd_uint::(rng, Uint::::BITS); + let num = random_odd_uint::>(rng, Uint::::BITS).get(); Integer::from_digits(num.as_words(), Order::Lsf) } diff --git a/src/hazmat/gcd.rs b/src/hazmat/gcd.rs index c51e4a2..57a0c85 100644 --- a/src/hazmat/gcd.rs +++ b/src/hazmat/gcd.rs @@ -1,9 +1,9 @@ -use crypto_bigint::{Limb, NonZero, Uint, Word}; +use crypto_bigint::{Integer, Limb, NonZero, Word}; /// Calculates the greatest common divisor of `n` and `m`. /// By definition, `gcd(0, m) == m`. /// `n` must be non-zero. -pub(crate) fn gcd_vartime(n: &Uint, m: Word) -> Word { +pub(crate) fn gcd_vartime(n: &T, m: Word) -> Word { // This is an internal function, and it will never be called with `m = 0`. // Allowing `m = 0` would require us to have the return type of `Uint` // (since `gcd(n, 0) = n`). @@ -11,7 +11,7 @@ pub(crate) fn gcd_vartime(n: &Uint, m: Word) -> Word { // This we can check since it doesn't affect the return type, // even though `n` will not be 0 either in the application. - if n == &Uint::::ZERO { + if n.is_zero().into() { return m; } @@ -23,7 +23,7 @@ pub(crate) fn gcd_vartime(n: &Uint, m: Word) -> Word { } else { // In this branch `n` is `Word::BITS` bits or shorter, // so we can safely take the first limb. - let n = n.as_words()[0]; + let n = n.as_ref()[0].0; if n > m { (n, m) } else { diff --git a/src/hazmat/jacobi.rs b/src/hazmat/jacobi.rs index f45434f..c435872 100644 --- a/src/hazmat/jacobi.rs +++ b/src/hazmat/jacobi.rs @@ -1,6 +1,6 @@ //! Jacobi symbol calculation. -use crypto_bigint::{Limb, NonZero, Odd, Uint, Word}; +use crypto_bigint::{Integer, Limb, NonZero, Odd, Word}; #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub(crate) enum JacobiSymbol { @@ -20,37 +20,13 @@ impl core::ops::Neg for JacobiSymbol { } } -// A helper trait to generalize some functions over Word and Uint. -trait SmallMod { - fn mod8(&self) -> Word; - fn mod4(&self) -> Word; -} - -impl SmallMod for Word { - fn mod8(&self) -> Word { - self & 7 - } - fn mod4(&self) -> Word { - self & 3 - } -} - -impl SmallMod for Uint { - fn mod8(&self) -> Word { - self.as_limbs()[0].0 & 7 - } - fn mod4(&self) -> Word { - self.as_limbs()[0].0 & 3 - } -} - /// Transforms `(a/p)` -> `(r/p)` for odd `p`, where the resulting `r` is odd, and `a = r * 2^s`. /// Takes a Jacobi symbol value, and returns `r` and the new Jacobi symbol, /// negated if the transformation changes parity. /// /// Note that the returned `r` is odd. -fn reduce_numerator(j: JacobiSymbol, a: Word, p: &V) -> (JacobiSymbol, Word) { - let p_mod_8 = p.mod8(); +fn apply_reduce_numerator(j: JacobiSymbol, a: Word, p: Word) -> (JacobiSymbol, Word) { + let p_mod_8 = p & 7; let s = a.trailing_zeros(); let j = if (s & 1) == 1 && (p_mod_8 == 3 || p_mod_8 == 5) { -j @@ -60,23 +36,40 @@ fn reduce_numerator(j: JacobiSymbol, a: Word, p: &V) -> (JacobiSymb (j, a >> s) } +fn reduce_numerator_long(j: JacobiSymbol, a: Word, p: &T) -> (JacobiSymbol, Word) { + apply_reduce_numerator(j, a, p.as_ref()[0].0) +} + +fn reduce_numerator_short(j: JacobiSymbol, a: Word, p: Word) -> (JacobiSymbol, Word) { + apply_reduce_numerator(j, a, p) +} + /// Transforms `(a/p)` -> `(p/a)` for odd and coprime `a` and `p`. /// Takes a Jacobi symbol value, and returns the swapped pair and the new Jacobi symbol, /// negated if the transformation changes parity. -fn swap(j: JacobiSymbol, a: T, p: V) -> (JacobiSymbol, V, T) { - let j = if a.mod4() == 1 || p.mod4() == 1 { +fn apply_swap(j: JacobiSymbol, a: Word, p: Word) -> JacobiSymbol { + if a & 3 == 1 || p & 3 == 1 { j } else { -j - }; + } +} + +fn swap_long(j: JacobiSymbol, a: Word, p: &Odd) -> (JacobiSymbol, &Odd, Word) { + let j = apply_swap(j, a, p.as_ref().as_ref()[0].0); + (j, p, a) +} + +fn swap_short(j: JacobiSymbol, a: Word, p: Word) -> (JacobiSymbol, Word, Word) { + let j = apply_swap(j, a, p); (j, p, a) } /// Returns the Jacobi symbol `(a/p)` given an odd `p`. -pub(crate) fn jacobi_symbol_vartime( +pub(crate) fn jacobi_symbol_vartime( abs_a: Word, a_is_negative: bool, - p_long: &Odd>, + p_long: &Odd, ) -> JacobiSymbol { let result = JacobiSymbol::One; // Keep track of all the sign flips here. @@ -84,14 +77,14 @@ pub(crate) fn jacobi_symbol_vartime( // (-a/n) = (-1/n) * (a/n) // = (-1)^((n-1)/2) * (a/n) // = (-1 if n = 3 mod 4 else 1) * (a/n) - let result = if a_is_negative && p_long.mod4() == 3 { + let result = if a_is_negative && p_long.as_ref().as_ref()[0].0 & 3 == 3 { -result } else { result }; // A degenerate case. - if abs_a == 1 || p_long.as_ref() == &Uint::::ONE { + if abs_a == 1 || p_long.as_ref() == &T::one() { return result; } @@ -100,14 +93,14 @@ pub(crate) fn jacobi_symbol_vartime( // Normalize input: at the end we want `a < p`, `p` odd, and both fitting into a `Word`. let (result, a, p): (JacobiSymbol, Word, Word) = if p_long.bits_vartime() <= Limb::BITS { let a = a_limb.0; - let p = p_long.as_limbs()[0].0; + let p = p_long.as_ref().as_ref()[0].0; (result, a % p, p) } else { - let (result, a) = reduce_numerator(result, a_limb.0, p_long.as_ref()); + let (result, a) = reduce_numerator_long(result, a_limb.0, p_long.as_ref()); if a == 1 { return result; } - let (result, a_long, p) = swap(result, a, p_long.get()); + let (result, a_long, p) = swap_long(result, a, p_long); // Can unwrap here, since `p` is swapped with `a`, // and `a` would be odd after `reduce_numerator()`. let a = @@ -127,7 +120,7 @@ pub(crate) fn jacobi_symbol_vartime( // At this point `p` is odd (either coming from outside of the `loop`, // or from the previous iteration, where a previously reduced `a` // was swapped into its place), so we can call this. - (result, a) = reduce_numerator(result, a, &p); + (result, a) = reduce_numerator_short(result, a, p); if a == 1 { return result; @@ -138,7 +131,7 @@ pub(crate) fn jacobi_symbol_vartime( // Note that technically `swap()` only returns a valid `result` if `a` and `p` are coprime. // But if they are not, we will return `Zero` eventually, // which is not affected by any sign changes. - (result, a, p) = swap(result, a, p); + (result, a, p) = swap_short(result, a, p); a %= p; } diff --git a/src/hazmat/lucas.rs b/src/hazmat/lucas.rs index d96676a..9430ec4 100644 --- a/src/hazmat/lucas.rs +++ b/src/hazmat/lucas.rs @@ -1,14 +1,12 @@ //! Lucas primality test. -use crypto_bigint::{ - modular::{MontyForm, MontyParams}, - CheckedAdd, Integer, Odd, Uint, Word, -}; +use crypto_bigint::{Integer, Monty, Odd, Square, Word}; use super::{ gcd::gcd_vartime, jacobi::{jacobi_symbol_vartime, JacobiSymbol}, Primality, }; +use crate::UintLike; /// The maximum number of attempts to find `D` such that `(D/n) == -1`. // This is widely believed to be impossible. @@ -30,7 +28,7 @@ pub trait LucasBase { /// Given an odd integer, returns `Ok((P, abs(Q), is_negative(Q)))` on success, /// or `Err(Primality)` if the primality for the given integer was discovered /// during the search for a base. - fn generate(&self, n: &Odd>) -> Result<(Word, Word, bool), Primality>; + fn generate(&self, n: &Odd) -> Result<(Word, Word, bool), Primality>; } /// "Method A" for selecting the base given in Baillie & Wagstaff[^Baillie1980], @@ -48,11 +46,11 @@ pub trait LucasBase { pub struct SelfridgeBase; impl LucasBase for SelfridgeBase { - fn generate(&self, n: &Odd>) -> Result<(Word, Word, bool), Primality> { + fn generate(&self, n: &Odd) -> Result<(Word, Word, bool), Primality> { let mut abs_d = 5; let mut d_is_negative = false; let n_is_small = n.bits_vartime() < Word::BITS; // if true, `n` fits into one `Word` - let small_n = n.as_words()[0]; + let small_n = n.as_ref().as_ref()[0].0; let mut attempts = 0; loop { if attempts >= MAX_ATTEMPTS { @@ -61,7 +59,7 @@ impl LucasBase for SelfridgeBase { if attempts >= ATTEMPTS_BEFORE_SQRT { let sqrt_n = n.sqrt_vartime(); - if &sqrt_n.wrapping_mul(&sqrt_n) == n { + if &sqrt_n.wrapping_mul(&sqrt_n) == n.as_ref() { return Err(Primality::Composite); } } @@ -113,7 +111,7 @@ impl LucasBase for SelfridgeBase { pub struct AStarBase; impl LucasBase for AStarBase { - fn generate(&self, n: &Odd>) -> Result<(Word, Word, bool), Primality> { + fn generate(&self, n: &Odd) -> Result<(Word, Word, bool), Primality> { SelfridgeBase.generate(n).map(|(p, abs_q, q_is_negative)| { if abs_q == 1 && q_is_negative { (5, 5, false) @@ -136,7 +134,7 @@ impl LucasBase for AStarBase { pub struct BruteForceBase; impl LucasBase for BruteForceBase { - fn generate(&self, n: &Odd>) -> Result<(Word, Word, bool), Primality> { + fn generate(&self, n: &Odd) -> Result<(Word, Word, bool), Primality> { let mut p = 3; let mut attempts = 0; @@ -147,7 +145,7 @@ impl LucasBase for BruteForceBase { if attempts >= ATTEMPTS_BEFORE_SQRT { let sqrt_n = n.sqrt_vartime(); - if &sqrt_n.wrapping_mul(&sqrt_n) == n { + if &sqrt_n.wrapping_mul(&sqrt_n) == n.as_ref() { return Err(Primality::Composite); } } @@ -164,7 +162,7 @@ impl LucasBase for BruteForceBase { // Since the loop proceeds in increasing P and starts with P - 2 == 1, // the shared prime factor must be P + 2. // If P + 2 == n, then n is prime; otherwise P + 2 is a proper factor of n. - let primality = if n.as_ref() == &Uint::::from(p + 2) { + let primality = if n.as_ref() == &T::from(p + 2) { Primality::Prime } else { Primality::Composite @@ -181,22 +179,22 @@ impl LucasBase for BruteForceBase { } /// For the given odd `n`, finds `s` and odd `d` such that `n + 1 == 2^s * d`. -fn decompose(n: &Odd>) -> (u32, Odd>) { +fn decompose(n: &Odd) -> (u32, Odd) { // Need to be careful here since `n + 1` can overflow. // Instead of adding 1 and counting trailing 0s, we count trailing ones on the original `n`. - let s = n.trailing_ones(); + let s = n.trailing_ones_vartime(); let d = if s < n.bits_precision() { // The shift won't overflow because of the check above. // The addition won't overflow since the original `n` was odd, // so we right-shifted at least once. n.as_ref() - .overflowing_shr(s) + .overflowing_shr_vartime(s) .expect("shift should be within range by construction") - .checked_add(&Uint::ONE) + .checked_add(&T::one()) .expect("addition should not overflow by construction") } else { - Uint::ONE + T::one() }; (s, Odd::new(d).expect("`d` should be odd by construction")) @@ -285,8 +283,8 @@ pub enum LucasCheck { /// Performs the primality test based on Lucas sequence. /// See [`LucasCheck`] for possible checks, and the implementors of [`LucasBase`] /// for the corresponding bases. -pub fn lucas_test( - candidate: &Odd>, +pub fn lucas_test( + candidate: &Odd, base: impl LucasBase, check: LucasCheck, ) -> Primality { @@ -330,8 +328,8 @@ pub fn lucas_test( // we check that gcd(n, Q) = 1 anyway - again, since `Q` is small, // it does not noticeably affect the performance. if abs_q != 1 - && gcd_vartime(candidate, abs_q) != 1 - && candidate.as_ref() > &Uint::::from(abs_q) + && gcd_vartime(candidate.as_ref(), abs_q) != 1 + && candidate.as_ref() > &T::from(abs_q) { return Primality::Composite; } @@ -342,19 +340,19 @@ pub fn lucas_test( // Some constants in Montgomery form - let params = MontyParams::::new(*candidate); + let params = ::Monty::new_params(candidate.clone()); - let zero = MontyForm::::zero(params); - let one = MontyForm::::one(params); - let two = one + one; - let minus_two = -two; + let zero = ::Monty::zero(params.clone()); + let one = ::Monty::one(params.clone()); + let two = one.clone() + &one; + let minus_two = -two.clone(); // Convert Q to Montgomery form let q = if q_is_one { - one + one.clone() } else { - let abs_q = MontyForm::::new(&Uint::::from(abs_q), params); + let abs_q = ::Monty::new(T::from(abs_q), params.clone()); if q_is_negative { -abs_q } else { @@ -365,9 +363,9 @@ pub fn lucas_test( // Convert P to Montgomery form let p = if p_is_one { - one + one.clone() } else { - MontyForm::::new(&Uint::::from(p), params) + ::Monty::new(T::from(p), params.clone()) }; // Compute d-th element of Lucas sequence (U_d(P, Q), V_d(P, Q)), where: @@ -385,20 +383,20 @@ pub fn lucas_test( // We can therefore start with k=0 and build up to k=d in log2(d) steps. // Starting with k = 0 - let mut vk = two; // keeps V_k - let mut uk = MontyForm::::zero(params); // keeps U_k - let mut qk = one; // keeps Q^k + let mut vk = two.clone(); // keeps V_k + let mut uk = ::Monty::zero(params.clone()); // keeps U_k + let mut qk = one.clone(); // keeps Q^k // D in Montgomery representation - note that it can be negative. - let abs_d = MontyForm::::new(&Uint::::from(abs_d), params); + let abs_d = ::Monty::new(T::from(abs_d), params); let d_m = if d_is_negative { -abs_d } else { abs_d }; for i in (0..d.bits_vartime()).rev() { // k' = k * 2 - let u_2k = uk * vk; - let v_2k = vk.square() - (qk + qk); let q_2k = qk.square(); + let u_2k = uk * &vk; + let v_2k = vk.square() - &(qk.clone() + &qk); uk = u_2k; vk = v_2k; @@ -407,11 +405,15 @@ pub fn lucas_test( if d.bit_vartime(i) { // k' = k + 1 - let (p_uk, p_vk) = if p_is_one { (uk, vk) } else { (p * uk, p * vk) }; + let (p_uk, p_vk) = if p_is_one { + (uk.clone(), vk.clone()) + } else { + (p.clone() * &uk, p.clone() * &vk) + }; - let u_k1 = (p_uk + vk).div_by_2(); - let v_k1 = (d_m * uk + p_vk).div_by_2(); - let q_k1 = qk * q; + let u_k1 = (p_uk + &vk).div_by_2(); + let v_k1 = (d_m.clone() * &uk + &p_vk).div_by_2(); + let q_k1 = qk * &q; uk = u_k1; vk = v_k1; @@ -469,7 +471,7 @@ pub fn lucas_test( // k' = 2k // V_{k'} = V_k^2 - 2 Q^k - vk = vk * vk - qk - qk; + vk = vk.square() - &qk - &qk; if check != LucasCheck::LucasV && vk == zero { return Primality::ProbablyPrime; @@ -483,10 +485,10 @@ pub fn lucas_test( if check == LucasCheck::LucasV { // At this point vk = V_{d * 2^(s-1)}. // Double the index again: - vk = vk * vk - qk - qk; // now vk = V_{d * 2^s} = V_{n+1} + vk = vk.square() - &qk - &qk; // now vk = V_{d * 2^s} = V_{n+1} // Lucas-V check[^Baillie2021]: if V_{n+1} == 2 Q, report `n` as prime. - if vk == q + q { + if vk == q.clone() + &q { return Primality::ProbablyPrime; } } @@ -507,7 +509,10 @@ mod tests { use super::{ decompose, lucas_test, AStarBase, BruteForceBase, LucasBase, LucasCheck, SelfridgeBase, }; - use crate::hazmat::{primes, pseudoprimes, Primality}; + use crate::{ + hazmat::{primes, pseudoprimes, Primality}, + UintLike, + }; #[test] fn bases_derived_traits() { @@ -552,10 +557,7 @@ mod tests { struct TestBase; impl LucasBase for TestBase { - fn generate( - &self, - _n: &Odd>, - ) -> Result<(Word, Word, bool), Primality> { + fn generate(&self, _n: &Odd) -> Result<(Word, Word, bool), Primality> { Ok((5, 5, false)) } } diff --git a/src/hazmat/miller_rabin.rs b/src/hazmat/miller_rabin.rs index faf1156..bc0ca07 100644 --- a/src/hazmat/miller_rabin.rs +++ b/src/hazmat/miller_rabin.rs @@ -1,13 +1,10 @@ //! Miller-Rabin primality test. +use crypto_bigint::{Integer, Monty, NonZero, Odd, PowBoundedExp, Square}; use rand_core::CryptoRngCore; -use crypto_bigint::{ - modular::{MontyForm, MontyParams}, - CheckedAdd, NonZero, Odd, RandomMod, Uint, -}; - use super::Primality; +use crate::UintLike; /// Precomputed data used to perform Miller-Rabin primality test[^Pomerance1980]. /// The numbers that pass it are commonly called "strong probable primes" @@ -17,31 +14,29 @@ use super::Primality; /// C. Pomerance, J. L. Selfridge, S. S. Wagstaff "The Pseudoprimes to 25*10^9", /// Math. Comp. 35 1003-1026 (1980), /// DOI: [10.2307/2006210](https://dx.doi.org/10.2307/2006210) -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct MillerRabin { - candidate: Uint, +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MillerRabin { + candidate: T, bit_length: u32, - montgomery_params: MontyParams, - one: MontyForm, - minus_one: MontyForm, + montgomery_params: <::Monty as Monty>::Params, + one: ::Monty, + minus_one: ::Monty, s: u32, - d: Uint, + d: T, } -impl MillerRabin { +impl MillerRabin { /// Initializes a Miller-Rabin test for `candidate`. - /// - /// Panics if `candidate` is even. - pub fn new(candidate: Odd>) -> Self { - let params = MontyParams::::new(candidate); - let one = MontyForm::::one(params); - let minus_one = -one; + pub fn new(candidate: Odd) -> Self { + let params = ::Monty::new_params(candidate.clone()); + let one = ::Monty::one(params.clone()); + let minus_one = -one.clone(); // Find `s` and odd `d` such that `candidate - 1 == 2^s * d`. - let (s, d) = if candidate.as_ref() == &Uint::ONE { - (0, Uint::ONE) + let (s, d) = if candidate.as_ref() == &T::one() { + (0, T::one()) } else { - let candidate_minus_one = candidate.wrapping_sub(&Uint::ONE); + let candidate_minus_one = candidate.wrapping_sub(&T::one()); let s = candidate_minus_one.trailing_zeros_vartime(); // Will not overflow because `candidate` is odd and greater than 1. let d = candidate_minus_one @@ -51,7 +46,7 @@ impl MillerRabin { }; Self { - candidate: *candidate, + candidate: candidate.as_ref().clone(), bit_length: candidate.bits_vartime(), montgomery_params: params, one, @@ -62,11 +57,11 @@ impl MillerRabin { } /// Perform a Miller-Rabin check with a given base. - pub fn test(&self, base: &Uint) -> Primality { + pub fn test(&self, base: &T) -> Primality { // TODO: it may be faster to first check that gcd(base, candidate) == 1, // otherwise we can return `Composite` right away. - let base = MontyForm::::new(base, self.montgomery_params); + let base = ::Monty::new(base.clone(), self.montgomery_params.clone()); // Implementation detail: bounded exp gets faster every time we decrease the bound // by the window length it uses, which is currently 4 bits. @@ -91,7 +86,7 @@ impl MillerRabin { /// Perform a Miller-Rabin check with base 2. pub fn test_base_two(&self) -> Primality { - self.test(&Uint::::from(2u32)) + self.test(&T::from(2u32)) } /// Perform a Miller-Rabin check with a random base (in the range `[3, candidate-2]`) @@ -106,14 +101,14 @@ impl MillerRabin { panic!("No suitable random base possible when `candidate == 3`; use the base 2 test.") } - let range = self.candidate.wrapping_sub(&Uint::::from(4u32)); + let range = self.candidate.wrapping_sub(&T::from(4u32)); // Can unwrap here since `candidate` is odd, and `candidate >= 4` (as checked above) let range_nonzero = NonZero::new(range).expect("the range should be non-zero by construction"); // This should not overflow as long as `random_mod()` behaves according to the contract // (that is, returns a number within the given range). - let random = Uint::::random_mod(rng, &range_nonzero) - .checked_add(&Uint::::from(3u32)) + let random = T::random_mod(rng, &range_nonzero) + .checked_add(&T::from(3u32)) .expect("addition should not overflow by construction"); self.test(&random) } @@ -132,7 +127,10 @@ mod tests { use num_prime::nt_funcs::is_prime64; use super::MillerRabin; - use crate::hazmat::{primes, pseudoprimes, random_odd_uint, Sieve}; + use crate::{ + hazmat::{primes, pseudoprimes, random_odd_uint, Sieve}, + UintLike, + }; #[test] fn miller_rabin_derived_traits() { @@ -154,9 +152,9 @@ mod tests { pseudoprimes::STRONG_BASE_2.iter().any(|x| *x == num) } - fn random_checks( + fn random_checks( rng: &mut impl CryptoRngCore, - mr: &MillerRabin, + mr: &MillerRabin, count: usize, ) -> usize { (0..count) @@ -194,8 +192,8 @@ mod tests { #[test] fn trivial() { let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901"); - let start: Odd = random_odd_uint(&mut rng, 1024); - for num in Sieve::new(&start, 1024, false).take(10) { + let start = random_odd_uint::(&mut rng, 1024); + for num in Sieve::new(start.as_ref(), 1024, false).take(10) { let mr = MillerRabin::new(Odd::new(num).unwrap()); // Trivial tests, must always be true. @@ -209,9 +207,9 @@ mod tests { let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901"); // Mersenne prime 2^127-1 - let num = U128::from_be_hex("7fffffffffffffffffffffffffffffff"); + let num = Odd::new(U128::from_be_hex("7fffffffffffffffffffffffffffffff")).unwrap(); - let mr = MillerRabin::new(Odd::new(num).unwrap()); + let mr = MillerRabin::new(num); assert!(mr.test_base_two().is_probably_prime()); for _ in 0..10 { assert!(mr.test_random_base(&mut rng).is_probably_prime()); diff --git a/src/hazmat/sieve.rs b/src/hazmat/sieve.rs index 8b10edd..f624d1f 100644 --- a/src/hazmat/sieve.rs +++ b/src/hazmat/sieve.rs @@ -3,44 +3,41 @@ use alloc::{vec, vec::Vec}; -use crypto_bigint::{CheckedAdd, Odd, Random, Uint}; +use crypto_bigint::Odd; use rand_core::CryptoRngCore; -use crate::hazmat::precomputed::{SmallPrime, RECIPROCALS, SMALL_PRIMES}; +use crate::{ + hazmat::precomputed::{SmallPrime, RECIPROCALS, SMALL_PRIMES}, + UintLike, +}; /// Returns a random odd integer with given bit length /// (that is, with both `0` and `bit_length-1` bits set). /// /// Panics if `bit_length` is 0 or is greater than the bit size of the target `Uint`. -pub fn random_odd_uint( - rng: &mut impl CryptoRngCore, - bit_length: u32, -) -> Odd> { +pub fn random_odd_uint(rng: &mut impl CryptoRngCore, bit_length: u32) -> Odd { if bit_length == 0 { panic!("Bit length must be non-zero"); } - if bit_length > Uint::::BITS { + // TODO: what do we do here if `bit_length` is greater than Uint::BITS? + // assume that the user knows what he's doing since it is a hazmat function? + /*if bit_length > Uint::::BITS { panic!( "The requested bit length ({}) is larger than the chosen Uint size", bit_length ); - } + }*/ // TODO: not particularly efficient, can be improved by zeroing high bits instead of shifting - let mut random = Uint::::random(rng); - if bit_length != Uint::::BITS { - random >>= Uint::::BITS - bit_length; - } + let mut random = T::random_bits(rng, bit_length); // Make it odd - random |= Uint::::ONE; + random.set_bit_vartime(0, true); // Make sure it's the correct bit size // Will not overflow since `bit_length` is ensured to be within the size of the integer. - random |= Uint::::ONE - .overflowing_shl_vartime(bit_length - 1) - .expect("shift should be within range by construction"); + random.set_bit_vartime(bit_length - 1, true); Odd::new(random).expect("the number should be odd by construction") } @@ -56,11 +53,11 @@ const INCR_LIMIT: Residue = Residue::MAX - SMALL_PRIMES[SMALL_PRIMES.len() - 1] /// An iterator returning numbers with up to and including given bit length, /// starting from a given number, that are not multiples of the first 2048 small primes. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Sieve { - // Instead of dividing `Uint` by small primes every time (which is slow), +pub struct Sieve { + // Instead of dividing a big integer by small primes every time (which is slow), // we keep a "base" and a small increment separately, // so that we can only calculate the residues of the increment. - base: Uint, + base: T, incr: Residue, incr_limit: Residue, safe_primes: bool, @@ -71,7 +68,7 @@ pub struct Sieve { last_round: bool, } -impl Sieve { +impl Sieve { /// Creates a new sieve, iterating from `start` and /// until the last number with `max_bit_length` bits, /// producing numbers that are not non-trivial multiples @@ -84,12 +81,12 @@ impl Sieve { /// Panics if `max_bit_length` is zero or greater than the size of the target `Uint`. /// /// If `safe_primes` is `true`, both the returned `n` and `n/2` are sieved. - pub fn new(start: &Uint, max_bit_length: u32, safe_primes: bool) -> Self { + pub fn new(start: &T, max_bit_length: u32, safe_primes: bool) -> Self { if max_bit_length == 0 { panic!("The requested bit length cannot be zero"); } - if max_bit_length > Uint::::BITS { + if max_bit_length > start.bits_precision() { panic!( "The requested bit length ({}) is larger than the chosen Uint size", max_bit_length @@ -101,33 +98,33 @@ impl Sieve { let (max_bit_length, base) = if safe_primes { (max_bit_length - 1, start.wrapping_shr_vartime(1)) } else { - (max_bit_length, *start) + (max_bit_length, start.clone()) }; let mut base = base; // This is easier than making all the methods generic enough to handle these corner cases. - let produces_nothing = max_bit_length < base.bits() || max_bit_length < 2; + let produces_nothing = max_bit_length < base.bits_vartime() || max_bit_length < 2; // Add the exception to the produced candidates - the only one that doesn't fit // the general pattern of incrementing the base by 2. let mut starts_from_exception = false; - if base <= Uint::::from(2u32) { + if base <= T::from(2u32) { starts_from_exception = true; - base = Uint::::from(3u32); + base = T::from(3u32); } else { // Adjust the base so that we hit odd numbers when incrementing it by 2. - base |= Uint::::ONE; + base |= T::one(); } // Only calculate residues by primes up to and not including `base`, // because when we only have the resiude, // we cannot distinguish between a prime itself and a multiple of that prime. - let residues_len = if Uint::::from(SMALL_PRIMES[SMALL_PRIMES.len() - 1]) >= base { + let residues_len = if T::from(SMALL_PRIMES[SMALL_PRIMES.len() - 1]) >= base { SMALL_PRIMES .iter() .enumerate() - .find(|(_i, p)| Uint::::from(**p) >= base) + .find(|(_i, p)| T::from(**p) >= base) .map(|(i, _p)| i) .unwrap_or(SMALL_PRIMES.len()) } else { @@ -175,12 +172,12 @@ impl Sieve { } // Find the increment limit. - let max_value = Uint::::ONE - .overflowing_shl(self.max_bit_length) - .unwrap_or(Uint::ZERO) - .wrapping_sub(&Uint::::ONE); + let max_value = match T::one().overflowing_shl_vartime(self.max_bit_length).into() { + Some(val) => val, + None => T::one(), + }; let incr_limit = max_value.wrapping_sub(&self.base); - self.incr_limit = if incr_limit > Uint::::from(INCR_LIMIT) { + self.incr_limit = if incr_limit > T::from(INCR_LIMIT) { INCR_LIMIT } else { // We are close to `2^max_bit_length - 1`. @@ -188,7 +185,8 @@ impl Sieve { self.last_round = true; // Can unwrap here since we just checked above that `incr_limit <= INCR_LIMIT`, // and `INCR_LIMIT` fits into `Residue`. - let incr_limit_small: Residue = incr_limit.as_words()[0] + let incr_limit_small: Residue = incr_limit.as_ref()[0] + .0 .try_into() .expect("the increment limit should fit within `Residue`"); incr_limit_small @@ -223,19 +221,19 @@ impl Sieve { // Returns the restored `base + incr` if it is not composite (wrt the small primes), // and bumps the increment unconditionally. - fn maybe_next(&mut self) -> Option> { + fn maybe_next(&mut self) -> Option { let result = if self.current_is_composite() { None } else { // The overflow should never happen here since `incr` // is never greater than `incr_limit`, and the latter is chosen such that // it does not overflow when added to `base` (see `update_residues()`). - let mut num: Uint = self + let mut num = self .base .checked_add(&self.incr.into()) .expect("addition should not overflow by construction"); if self.safe_primes { - num = num.wrapping_shl_vartime(1) | Uint::::ONE; + num = num.wrapping_shl_vartime(1) | T::one(); } Some(num) }; @@ -244,7 +242,7 @@ impl Sieve { result } - fn next(&mut self) -> Option> { + fn next(&mut self) -> Option { // Corner cases handled here if self.produces_nothing { @@ -253,7 +251,7 @@ impl Sieve { if self.starts_from_exception { self.starts_from_exception = false; - return Some(Uint::::from(if self.safe_primes { 5u32 } else { 2u32 })); + return Some(T::from(if self.safe_primes { 5u32 } else { 2u32 })); } // Main loop @@ -268,8 +266,8 @@ impl Sieve { } } -impl Iterator for Sieve { - type Item = Uint; +impl Iterator for Sieve { + type Item = T; fn next(&mut self) -> Option { Self::next(self) @@ -282,7 +280,7 @@ mod tests { use alloc::format; use alloc::vec::Vec; - use crypto_bigint::{Odd, U64}; + use crypto_bigint::U64; use num_prime::nt_funcs::factorize64; use rand_chacha::ChaCha8Rng; use rand_core::{OsRng, SeedableRng}; @@ -295,9 +293,9 @@ mod tests { let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1]; let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901"); - let start: Odd = random_odd_uint(&mut rng, 32); + let start = random_odd_uint::(&mut rng, 32).get(); for num in Sieve::new(&start, 32, false).take(100) { - let num_u64: u64 = num.into(); + let num_u64 = u64::from(num); assert!(num_u64.leading_zeros() == 32); let factors_and_powers = factorize64(num_u64); @@ -375,7 +373,7 @@ mod tests { #[test] fn random_below_max_length() { for _ in 0..10 { - let r: Odd = random_odd_uint(&mut OsRng, 50); + let r = random_odd_uint::(&mut OsRng, 50).get(); assert_eq!(r.bits(), 50); } } @@ -383,13 +381,13 @@ mod tests { #[test] #[should_panic(expected = "Bit length must be non-zero")] fn random_odd_uint_0bits() { - let _p: Odd = random_odd_uint(&mut OsRng, 0); + let _p = random_odd_uint::(&mut OsRng, 0); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn random_odd_uint_too_many_bits() { - let _p: Odd = random_odd_uint(&mut OsRng, 65); + let _p = random_odd_uint::(&mut OsRng, 65); } #[test] diff --git a/src/lib.rs b/src/lib.rs index f708159..03a4640 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ extern crate alloc; pub mod hazmat; mod presets; mod traits; +mod uint_traits; pub use presets::{ generate_prime_with_rng, generate_safe_prime_with_rng, is_prime_with_rng, @@ -27,3 +28,5 @@ pub use traits::RandomPrimeWithRng; #[cfg(feature = "default-rng")] pub use presets::{generate_prime, generate_safe_prime, is_prime, is_safe_prime}; + +pub use uint_traits::UintLike; diff --git a/src/presets.rs b/src/presets.rs index 5b880ab..7232dbe 100644 --- a/src/presets.rs +++ b/src/presets.rs @@ -1,4 +1,4 @@ -use crypto_bigint::{Odd, Uint}; +use crypto_bigint::Odd; use rand_core::CryptoRngCore; #[cfg(feature = "default-rng")] @@ -7,13 +7,14 @@ use rand_core::OsRng; use crate::hazmat::{ lucas_test, random_odd_uint, AStarBase, LucasCheck, MillerRabin, Primality, Sieve, }; +use crate::UintLike; /// Returns a random prime of size `bit_length` using [`OsRng`] as the RNG. /// If `bit_length` is `None`, the full size of `Uint` is used. /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn generate_prime(bit_length: Option) -> Uint { +pub fn generate_prime(bit_length: u32) -> T { generate_prime_with_rng(&mut OsRng, bit_length) } @@ -23,7 +24,7 @@ pub fn generate_prime(bit_length: Option) -> Uint { /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn generate_safe_prime(bit_length: Option) -> Uint { +pub fn generate_safe_prime(bit_length: u32) -> T { generate_safe_prime_with_rng(&mut OsRng, bit_length) } @@ -31,7 +32,7 @@ pub fn generate_safe_prime(bit_length: Option) -> Uint { /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn is_prime(num: &Uint) -> bool { +pub fn is_prime(num: &T) -> bool { is_prime_with_rng(&mut OsRng, num) } @@ -41,7 +42,7 @@ pub fn is_prime(num: &Uint) -> bool { /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn is_safe_prime(num: &Uint) -> bool { +pub fn is_safe_prime(num: &T) -> bool { is_safe_prime_with_rng(&mut OsRng, num) } @@ -51,17 +52,13 @@ pub fn is_safe_prime(num: &Uint) -> bool { /// Panics if `bit_length` is less than 2, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. -pub fn generate_prime_with_rng( - rng: &mut impl CryptoRngCore, - bit_length: Option, -) -> Uint { - let bit_length = bit_length.unwrap_or(Uint::::BITS); +pub fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> T { if bit_length < 2 { panic!("`bit_length` must be 2 or greater."); } loop { - let start = random_odd_uint::(rng, bit_length); - let sieve = Sieve::new(&start, bit_length, false); + let start = random_odd_uint::(rng, bit_length); + let sieve = Sieve::new(start.as_ref(), bit_length, false); for num in sieve { if is_prime_with_rng(rng, &num) { return num; @@ -77,17 +74,16 @@ pub fn generate_prime_with_rng( /// Panics if `bit_length` is less than 3, or is greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. -pub fn generate_safe_prime_with_rng( +pub fn generate_safe_prime_with_rng( rng: &mut impl CryptoRngCore, - bit_length: Option, -) -> Uint { - let bit_length = bit_length.unwrap_or(Uint::::BITS); + bit_length: u32, +) -> T { if bit_length < 3 { panic!("`bit_length` must be 3 or greater."); } loop { - let start = random_odd_uint::(rng, bit_length); - let sieve = Sieve::new(&start, bit_length, true); + let start = random_odd_uint::(rng, bit_length); + let sieve = Sieve::new(start.as_ref(), bit_length, true); for num in sieve { if is_safe_prime_with_rng(rng, &num) { return num; @@ -121,12 +117,12 @@ pub fn generate_safe_prime_with_rng( /// "Strengthening the Baillie-PSW primality test", /// Math. Comp. 90 1931-1955 (2021), /// DOI: [10.1090/mcom/3616](https://doi.org/10.1090/mcom/3616) -pub fn is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Uint) -> bool { - if num == &Uint::::from(2u32) { +pub fn is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &T) -> bool { + if num == &T::from(2u32) { return true; } - let odd_num = match Odd::new(*num).into() { + let odd_num = match Odd::new(num.clone()).into() { Some(x) => x, None => return false, }; @@ -137,27 +133,27 @@ pub fn is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Uin /// Checks probabilistically if the given number is a safe prime using the provided RNG. /// /// See [`is_prime_with_rng`] for details about the performed checks. -pub fn is_safe_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Uint) -> bool { +pub fn is_safe_prime_with_rng(rng: &mut impl CryptoRngCore, num: &T) -> bool { // Since, by the definition of safe prime, `(num - 1) / 2` must also be prime, // and therefore odd, `num` has to be equal to 3 modulo 4. // 5 is the only exception, so we check for it. - if num == &Uint::::from(5u32) { + if num == &T::from(5u32) { return true; } - if num.as_words()[0] & 3 != 3 { + if num.as_ref()[0].0 & 3 != 3 { return false; } // These are ensured to be odd by the check above. - let odd_num = Odd::new(*num).expect("`num` should be odd here"); + let odd_num = Odd::new(num.clone()).expect("`num` should be odd here"); let odd_half_num = Odd::new(num.wrapping_shr_vartime(1)).expect("`num/2` should be odd here"); _is_prime_with_rng(rng, &odd_num) && _is_prime_with_rng(rng, &odd_half_num) } /// Checks for primality assuming that `num` is odd. -fn _is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Odd>) -> bool { - let mr = MillerRabin::new(*num); +fn _is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Odd) -> bool { + let mr = MillerRabin::new(num.clone()); if !mr.test_base_two().is_probably_prime() { return false; @@ -238,8 +234,7 @@ mod tests { } next = next - .overflowing_shl_vartime(1) - .unwrap() + .wrapping_shl_vartime(1) .checked_add(&Uint::::ONE) .unwrap(); } @@ -261,7 +256,7 @@ mod tests { #[test] fn prime_generation() { for bit_length in (28..=128).step_by(10) { - let p: U128 = generate_prime(Some(bit_length)); + let p: U128 = generate_prime(bit_length); assert!(p.bits_vartime() == bit_length); assert!(is_prime(&p)); } @@ -270,7 +265,7 @@ mod tests { #[test] fn safe_prime_generation() { for bit_length in (28..=128).step_by(10) { - let p: U128 = generate_safe_prime(Some(bit_length)); + let p: U128 = generate_safe_prime(bit_length); assert!(p.bits_vartime() == bit_length); assert!(is_safe_prime(&p)); } @@ -303,25 +298,25 @@ mod tests { #[test] #[should_panic(expected = "`bit_length` must be 2 or greater")] fn generate_prime_too_few_bits() { - let _p: U64 = generate_prime_with_rng(&mut OsRng, Some(1)); + let _p: U64 = generate_prime_with_rng(&mut OsRng, 1); } #[test] #[should_panic(expected = "`bit_length` must be 3 or greater")] fn generate_safe_prime_too_few_bits() { - let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, Some(2)); + let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 2); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn generate_prime_too_many_bits() { - let _p: U64 = generate_prime_with_rng(&mut OsRng, Some(65)); + let _p: U64 = generate_prime_with_rng(&mut OsRng, 65); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn generate_safe_prime_too_many_bits() { - let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, Some(65)); + let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 65); } fn is_prime_ref(num: Word) -> bool { @@ -332,7 +327,7 @@ mod tests { fn corner_cases_generate_prime() { for bits in 2..5 { for _ in 0..100 { - let p: U64 = generate_prime(Some(bits)); + let p: U64 = generate_prime(bits); let p_word = p.as_words()[0]; assert!(is_prime_ref(p_word)); } @@ -343,7 +338,7 @@ mod tests { fn corner_cases_generate_safe_prime() { for bits in 3..5 { for _ in 0..100 { - let p: U64 = generate_safe_prime(Some(bits)); + let p: U64 = generate_safe_prime(bits); let p_word = p.as_words()[0]; assert!(is_prime_ref(p_word) && is_prime_ref(p_word / 2)); } @@ -356,7 +351,7 @@ mod tests { mod tests_openssl { use alloc::format; - use crypto_bigint::{Odd, U128}; + use crypto_bigint::U128; use openssl::bn::{BigNum, BigNumContext}; use rand_core::OsRng; @@ -381,7 +376,7 @@ mod tests_openssl { // Generate primes, let OpenSSL check them for _ in 0..100 { - let p: U128 = generate_prime(Some(128)); + let p: U128 = generate_prime(128); let p_bn = to_openssl(&p); assert!( openssl_is_prime(&p_bn, &mut ctx), @@ -399,8 +394,8 @@ mod tests_openssl { // Generate random numbers, check if our test agrees with OpenSSL for _ in 0..100 { - let p: Odd = random_odd_uint(&mut OsRng, 128); - let actual = is_prime(&p); + let p = random_odd_uint::(&mut OsRng, 128); + let actual = is_prime(p.as_ref()); let p_bn = to_openssl(&p); let expected = openssl_is_prime(&p_bn, &mut ctx); assert_eq!( @@ -414,7 +409,7 @@ mod tests_openssl { #[cfg(test)] #[cfg(feature = "tests-gmp")] mod tests_gmp { - use crypto_bigint::{Odd, U128}; + use crypto_bigint::U128; use rand_core::OsRng; use rug::{ integer::{IsPrime, Order}, @@ -440,14 +435,14 @@ mod tests_gmp { fn gmp_cross_check() { // Generate primes, let GMP check them for _ in 0..100 { - let p: U128 = generate_prime(Some(128)); + let p: U128 = generate_prime(128); let p_bn = to_gmp(&p); assert!(gmp_is_prime(&p_bn), "GMP reports {p} as composite"); } // Generate primes with GMP, check them for _ in 0..100 { - let start: Odd = random_odd_uint(&mut OsRng, 128); + let start = random_odd_uint::(&mut OsRng, 128); let start_bn = to_gmp(&start); let p_bn = start_bn.next_prime(); let p = from_gmp(&p_bn); @@ -456,8 +451,8 @@ mod tests_gmp { // Generate random numbers, check if our test agrees with GMP for _ in 0..100 { - let p: Odd = random_odd_uint(&mut OsRng, 128); - let actual = is_prime(&p); + let p = random_odd_uint::(&mut OsRng, 128); + let actual = is_prime(p.as_ref()); let p_bn = to_gmp(&p); let expected = gmp_is_prime(&p_bn); assert_eq!( diff --git a/src/traits.rs b/src/traits.rs index bb72242..458d85c 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -15,7 +15,7 @@ pub trait RandomPrimeWithRng { /// Panics if `bit_length` is less than 2, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. - fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: Option) -> Self; + fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; /// Returns a random safe prime (that is, such that `(n - 1) / 2` is also prime) /// of size `bit_length` using the provided RNG. @@ -24,7 +24,7 @@ pub trait RandomPrimeWithRng { /// Panics if `bit_length` is less than 3, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. - fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: Option) -> Self; + fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; /// Checks probabilistically if the given number is prime using the provided RNG. /// @@ -38,10 +38,10 @@ pub trait RandomPrimeWithRng { } impl RandomPrimeWithRng for Uint { - fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: Option) -> Self { + fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { generate_prime_with_rng(rng, bit_length) } - fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: Option) -> Self { + fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { generate_safe_prime_with_rng(rng, bit_length) } fn is_prime_with_rng(&self, rng: &mut impl CryptoRngCore) -> bool { @@ -67,8 +67,9 @@ mod tests { assert!(!U64::from(13u32).is_safe_prime_with_rng(&mut OsRng)); assert!(U64::from(11u32).is_safe_prime_with_rng(&mut OsRng)); - assert!(U64::generate_prime_with_rng(&mut OsRng, Some(10)).is_prime_with_rng(&mut OsRng)); - assert!(U64::generate_safe_prime_with_rng(&mut OsRng, Some(10)) - .is_safe_prime_with_rng(&mut OsRng)); + assert!(U64::generate_prime_with_rng(&mut OsRng, 10).is_prime_with_rng(&mut OsRng)); + assert!( + U64::generate_safe_prime_with_rng(&mut OsRng, 10).is_safe_prime_with_rng(&mut OsRng) + ); } } diff --git a/src/uint_traits.rs b/src/uint_traits.rs new file mode 100644 index 0000000..015ba4e --- /dev/null +++ b/src/uint_traits.rs @@ -0,0 +1,120 @@ +#![allow(missing_docs)] + +use crypto_bigint::{subtle::CtOption, BoxedUint, Integer, Random, RandomMod, Uint}; +use rand_core::CryptoRngCore; + +// would be nice to have: *Assign traits; arithmetic traits for &self (BitAnd and Shr in particular); +pub trait UintLike: Integer + RandomMod { + fn bit_vartime(&self, index: u32) -> bool; + fn set_bit_vartime(&mut self, index: u32, value: bool); + fn trailing_zeros_vartime(&self) -> u32; + fn trailing_ones_vartime(&self) -> u32; + fn sqrt_vartime(&self) -> Self; + fn overflowing_shl_vartime(&self, shift: u32) -> CtOption; + fn overflowing_shr_vartime(&self, shift: u32) -> CtOption; + fn wrapping_shl_vartime(&self, shift: u32) -> Self; + fn wrapping_shr_vartime(&self, shift: u32) -> Self; + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; +} + +// Uint impls + +impl UintLike for Uint { + fn set_bit_vartime(&mut self, index: u32, value: bool) { + if value { + *self |= Uint::ONE << index + } else { + *self &= (Uint::ONE << index).not() + } + } + + fn trailing_zeros_vartime(&self) -> u32 { + Self::trailing_zeros_vartime(self) + } + + fn trailing_ones_vartime(&self) -> u32 { + Self::trailing_ones_vartime(self) + } + + fn bit_vartime(&self, index: u32) -> bool { + Self::bit_vartime(self, index) + } + + fn sqrt_vartime(&self) -> Self { + Self::sqrt_vartime(self) + } + + fn overflowing_shl_vartime(&self, shift: u32) -> CtOption { + Self::overflowing_shl_vartime(self, shift).into() + } + + fn overflowing_shr_vartime(&self, shift: u32) -> CtOption { + Self::overflowing_shr_vartime(self, shift).into() + } + + fn wrapping_shl_vartime(&self, shift: u32) -> Self { + Self::wrapping_shl_vartime(self, shift) + } + + fn wrapping_shr_vartime(&self, shift: u32) -> Self { + Self::wrapping_shr_vartime(self, shift) + } + + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { + if bit_length > Self::BITS { + panic!("The requested bit length ({bit_length}) is larger than the chosen Uint size"); + } + let random = Self::random(rng); + random >> (Self::BITS - bit_length) + } +} + +impl UintLike for BoxedUint { + fn set_bit_vartime(&mut self, index: u32, value: bool) { + if value { + *self |= Self::one() << index + } else { + *self &= (Self::one() << index).not() + } + } + + fn trailing_zeros_vartime(&self) -> u32 { + Self::trailing_zeros(self) + } + + fn trailing_ones_vartime(&self) -> u32 { + Self::trailing_ones_vartime(self) + } + + fn bit_vartime(&self, index: u32) -> bool { + Self::bit_vartime(self, index) + } + + fn sqrt_vartime(&self) -> Self { + Self::sqrt_vartime(self) + } + + fn overflowing_shl_vartime(&self, shift: u32) -> CtOption { + let (res, is_some) = Self::overflowing_shl(self, shift); + CtOption::new(res, is_some) + } + + fn overflowing_shr_vartime(&self, shift: u32) -> CtOption { + let (res, is_some) = Self::overflowing_shr(self, shift); + CtOption::new(res, is_some) + } + + fn wrapping_shl_vartime(&self, shift: u32) -> Self { + Self::wrapping_shl_vartime(self, shift) + } + + fn wrapping_shr_vartime(&self, shift: u32) -> Self { + Self::wrapping_shr_vartime(self, shift) + } + + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { + let random = Self::random(rng, bit_length); + let bits_precision = random.bits_precision(); + random >> (bits_precision - bit_length) + } +} From b72dfb0cb442e9ee85f401add7e64ef88ffebc69 Mon Sep 17 00:00:00 2001 From: "Ganyu (Bruce) Xu" Date: Mon, 25 Dec 2023 18:44:55 +0800 Subject: [PATCH 2/5] Added bits_precision to Uint/BoxedUint agnostic functions --- benches/bench.rs | 34 ++++++++++++++-------------- src/hazmat/miller_rabin.rs | 2 +- src/hazmat/sieve.rs | 16 +++++++++----- src/presets.rs | 45 +++++++++++++++++++++----------------- src/traits.rs | 33 +++++++++++++++++++++------- src/uint_traits.rs | 18 ++++++++++----- 6 files changed, 91 insertions(+), 57 deletions(-) diff --git a/benches/bench.rs b/benches/bench.rs index edd5637..d5ce9ba 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -23,7 +23,7 @@ fn make_rng() -> ChaCha8Rng { } fn make_sieve(rng: &mut impl CryptoRngCore) -> Sieve> { - let start = random_odd_uint::>(rng, Uint::::BITS); + let start = random_odd_uint::>(rng, Uint::::BITS, Uint::::BITS); Sieve::new(&start, Uint::::BITS, false) } @@ -36,12 +36,12 @@ fn bench_sieve(c: &mut Criterion) { let mut group = c.benchmark_group("Sieve"); group.bench_function("(U128) random start", |b| { - b.iter(|| random_odd_uint::>(&mut OsRng, 128)) + b.iter(|| random_odd_uint::>(&mut OsRng, 128, 128)) }); group.bench_function("(U128) creation", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 128), + || random_odd_uint::>(&mut OsRng, 128, 128), |start| Sieve::new(start.as_ref(), 128, false), BatchSize::SmallInput, ) @@ -57,12 +57,12 @@ fn bench_sieve(c: &mut Criterion) { }); group.bench_function("(U1024) random start", |b| { - b.iter(|| random_odd_uint::>(&mut OsRng, 1024)) + b.iter(|| random_odd_uint::>(&mut OsRng, 1024, 1024)) }); group.bench_function("(U1024) creation", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 1024), + || random_odd_uint::>(&mut OsRng, 1024, 1024), |start| Sieve::new(start.as_ref(), 1024, false), BatchSize::SmallInput, ) @@ -84,7 +84,7 @@ fn bench_miller_rabin(c: &mut Criterion) { group.bench_function("(U128) creation", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 128), + || random_odd_uint::>(&mut OsRng, 128, 128), MillerRabin::new, BatchSize::SmallInput, ) @@ -100,7 +100,7 @@ fn bench_miller_rabin(c: &mut Criterion) { group.bench_function("(U1024) creation", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 1024), + || random_odd_uint::>(&mut OsRng, 1024, 1024), MillerRabin::new, BatchSize::SmallInput, ) @@ -193,7 +193,7 @@ fn bench_presets(c: &mut Criterion) { group.bench_function("(U128) Prime test", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 128), + || random_odd_uint::>(&mut OsRng, 128, 128), |num| is_prime_with_rng(&mut OsRng, num.as_ref()), BatchSize::SmallInput, ) @@ -201,7 +201,7 @@ fn bench_presets(c: &mut Criterion) { group.bench_function("(U128) Safe prime test", |b| { b.iter_batched( - || random_odd_uint::>(&mut OsRng, 128), + || random_odd_uint::>(&mut OsRng, 128, 128), |num| is_safe_prime_with_rng(&mut OsRng, num.as_ref()), BatchSize::SmallInput, ) @@ -209,23 +209,23 @@ fn bench_presets(c: &mut Criterion) { let mut rng = make_rng(); group.bench_function("(U128) Random prime", |b| { - b.iter(|| generate_prime_with_rng::>(&mut rng, 128)) + b.iter(|| generate_prime_with_rng::>(&mut rng, 128, 128)) }); let mut rng = make_rng(); group.bench_function("(U1024) Random prime", |b| { - b.iter(|| generate_prime_with_rng::>(&mut rng, 1024)) + b.iter(|| generate_prime_with_rng::>(&mut rng, 1024, 1024)) }); let mut rng = make_rng(); group.bench_function("(U128) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128, 128)) }); group.sample_size(20); let mut rng = make_rng(); group.bench_function("(U1024) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 1024)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 1024, 1024)) }); group.finish(); @@ -235,19 +235,19 @@ fn bench_presets(c: &mut Criterion) { let mut rng = make_rng(); group.bench_function("(U128) Random safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128, 128)) }); // The performance should scale with the prime size, not with the Uint size. // So we should strive for this test's result to be as close as possible // to that of the previous one and as far away as possible from the next one. group.bench_function("(U256) Random 128 bit safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 128, 256)) }); // The upper bound for the previous test. group.bench_function("(U256) Random 256 bit safe prime", |b| { - b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 256)) + b.iter(|| generate_safe_prime_with_rng::>(&mut rng, 256, 256)) }); group.finish(); @@ -258,7 +258,7 @@ fn bench_gmp(c: &mut Criterion) { let mut group = c.benchmark_group("GMP"); fn random(rng: &mut impl CryptoRngCore) -> Integer { - let num = random_odd_uint::>(rng, Uint::::BITS).get(); + let num = random_odd_uint::>(rng, Uint::::BITS, Uint::::BITS).get(); Integer::from_digits(num.as_words(), Order::Lsf) } diff --git a/src/hazmat/miller_rabin.rs b/src/hazmat/miller_rabin.rs index bc0ca07..aa4ff40 100644 --- a/src/hazmat/miller_rabin.rs +++ b/src/hazmat/miller_rabin.rs @@ -192,7 +192,7 @@ mod tests { #[test] fn trivial() { let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901"); - let start = random_odd_uint::(&mut rng, 1024); + let start = random_odd_uint::(&mut rng, 1024, U1024::BITS); for num in Sieve::new(start.as_ref(), 1024, false).take(10) { let mr = MillerRabin::new(Odd::new(num).unwrap()); diff --git a/src/hazmat/sieve.rs b/src/hazmat/sieve.rs index f624d1f..871fcab 100644 --- a/src/hazmat/sieve.rs +++ b/src/hazmat/sieve.rs @@ -15,7 +15,11 @@ use crate::{ /// (that is, with both `0` and `bit_length-1` bits set). /// /// Panics if `bit_length` is 0 or is greater than the bit size of the target `Uint`. -pub fn random_odd_uint(rng: &mut impl CryptoRngCore, bit_length: u32) -> Odd { +pub fn random_odd_uint( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, +) -> Odd { if bit_length == 0 { panic!("Bit length must be non-zero"); } @@ -30,7 +34,7 @@ pub fn random_odd_uint(rng: &mut impl CryptoRngCore, bit_length: u3 }*/ // TODO: not particularly efficient, can be improved by zeroing high bits instead of shifting - let mut random = T::random_bits(rng, bit_length); + let mut random = T::random_bits(rng, bit_length, bits_precision); // Make it odd random.set_bit_vartime(0, true); @@ -293,7 +297,7 @@ mod tests { let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1]; let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901"); - let start = random_odd_uint::(&mut rng, 32).get(); + let start = random_odd_uint::(&mut rng, 32, U64::BITS).get(); for num in Sieve::new(&start, 32, false).take(100) { let num_u64 = u64::from(num); assert!(num_u64.leading_zeros() == 32); @@ -373,7 +377,7 @@ mod tests { #[test] fn random_below_max_length() { for _ in 0..10 { - let r = random_odd_uint::(&mut OsRng, 50).get(); + let r = random_odd_uint::(&mut OsRng, 50, U64::BITS).get(); assert_eq!(r.bits(), 50); } } @@ -381,13 +385,13 @@ mod tests { #[test] #[should_panic(expected = "Bit length must be non-zero")] fn random_odd_uint_0bits() { - let _p = random_odd_uint::(&mut OsRng, 0); + let _p = random_odd_uint::(&mut OsRng, 0, U64::BITS); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn random_odd_uint_too_many_bits() { - let _p = random_odd_uint::(&mut OsRng, 65); + let _p = random_odd_uint::(&mut OsRng, 65, U64::BITS); } #[test] diff --git a/src/presets.rs b/src/presets.rs index 7232dbe..ebde8df 100644 --- a/src/presets.rs +++ b/src/presets.rs @@ -14,8 +14,8 @@ use crate::UintLike; /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn generate_prime(bit_length: u32) -> T { - generate_prime_with_rng(&mut OsRng, bit_length) +pub fn generate_prime(bit_length: u32, bits_precision: u32) -> T { + generate_prime_with_rng(&mut OsRng, bit_length, bits_precision) } /// Returns a random safe prime (that is, such that `(n - 1) / 2` is also prime) @@ -24,8 +24,8 @@ pub fn generate_prime(bit_length: u32) -> T { /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "default-rng")] -pub fn generate_safe_prime(bit_length: u32) -> T { - generate_safe_prime_with_rng(&mut OsRng, bit_length) +pub fn generate_safe_prime(bit_length: u32, bits_precision: u32) -> T { + generate_safe_prime_with_rng(&mut OsRng, bit_length, bits_precision) } /// Checks probabilistically if the given number is prime using [`OsRng`] as the RNG. @@ -52,12 +52,16 @@ pub fn is_safe_prime(num: &T) -> bool { /// Panics if `bit_length` is less than 2, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. -pub fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> T { +pub fn generate_prime_with_rng( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, +) -> T { if bit_length < 2 { panic!("`bit_length` must be 2 or greater."); } loop { - let start = random_odd_uint::(rng, bit_length); + let start = random_odd_uint::(rng, bit_length, bits_precision); let sieve = Sieve::new(start.as_ref(), bit_length, false); for num in sieve { if is_prime_with_rng(rng, &num) { @@ -77,12 +81,13 @@ pub fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_le pub fn generate_safe_prime_with_rng( rng: &mut impl CryptoRngCore, bit_length: u32, + bits_precision: u32, ) -> T { if bit_length < 3 { panic!("`bit_length` must be 3 or greater."); } loop { - let start = random_odd_uint::(rng, bit_length); + let start = random_odd_uint::(rng, bit_length, bits_precision); let sieve = Sieve::new(start.as_ref(), bit_length, true); for num in sieve { if is_safe_prime_with_rng(rng, &num) { @@ -256,7 +261,7 @@ mod tests { #[test] fn prime_generation() { for bit_length in (28..=128).step_by(10) { - let p: U128 = generate_prime(bit_length); + let p: U128 = generate_prime(bit_length, U128::BITS); assert!(p.bits_vartime() == bit_length); assert!(is_prime(&p)); } @@ -265,7 +270,7 @@ mod tests { #[test] fn safe_prime_generation() { for bit_length in (28..=128).step_by(10) { - let p: U128 = generate_safe_prime(bit_length); + let p: U128 = generate_safe_prime(bit_length, U128::BITS); assert!(p.bits_vartime() == bit_length); assert!(is_safe_prime(&p)); } @@ -298,25 +303,25 @@ mod tests { #[test] #[should_panic(expected = "`bit_length` must be 2 or greater")] fn generate_prime_too_few_bits() { - let _p: U64 = generate_prime_with_rng(&mut OsRng, 1); + let _p: U64 = generate_prime_with_rng(&mut OsRng, 1, U64::BITS); } #[test] #[should_panic(expected = "`bit_length` must be 3 or greater")] fn generate_safe_prime_too_few_bits() { - let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 2); + let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 2, U64::BITS); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn generate_prime_too_many_bits() { - let _p: U64 = generate_prime_with_rng(&mut OsRng, 65); + let _p: U64 = generate_prime_with_rng(&mut OsRng, 65, U64::BITS); } #[test] #[should_panic(expected = "The requested bit length (65) is larger than the chosen Uint size")] fn generate_safe_prime_too_many_bits() { - let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 65); + let _p: U64 = generate_safe_prime_with_rng(&mut OsRng, 65, U64::BITS); } fn is_prime_ref(num: Word) -> bool { @@ -327,7 +332,7 @@ mod tests { fn corner_cases_generate_prime() { for bits in 2..5 { for _ in 0..100 { - let p: U64 = generate_prime(bits); + let p: U64 = generate_prime(bits, U64::BITS); let p_word = p.as_words()[0]; assert!(is_prime_ref(p_word)); } @@ -338,7 +343,7 @@ mod tests { fn corner_cases_generate_safe_prime() { for bits in 3..5 { for _ in 0..100 { - let p: U64 = generate_safe_prime(bits); + let p: U64 = generate_safe_prime(bits, U64::BITS); let p_word = p.as_words()[0]; assert!(is_prime_ref(p_word) && is_prime_ref(p_word / 2)); } @@ -376,7 +381,7 @@ mod tests_openssl { // Generate primes, let OpenSSL check them for _ in 0..100 { - let p: U128 = generate_prime(128); + let p: U128 = generate_prime(128, U128::BITS); let p_bn = to_openssl(&p); assert!( openssl_is_prime(&p_bn, &mut ctx), @@ -394,7 +399,7 @@ mod tests_openssl { // Generate random numbers, check if our test agrees with OpenSSL for _ in 0..100 { - let p = random_odd_uint::(&mut OsRng, 128); + let p = random_odd_uint::(&mut OsRng, 128, U128::BITS); let actual = is_prime(p.as_ref()); let p_bn = to_openssl(&p); let expected = openssl_is_prime(&p_bn, &mut ctx); @@ -435,14 +440,14 @@ mod tests_gmp { fn gmp_cross_check() { // Generate primes, let GMP check them for _ in 0..100 { - let p: U128 = generate_prime(128); + let p: U128 = generate_prime(128, U128::BITS); let p_bn = to_gmp(&p); assert!(gmp_is_prime(&p_bn), "GMP reports {p} as composite"); } // Generate primes with GMP, check them for _ in 0..100 { - let start = random_odd_uint::(&mut OsRng, 128); + let start = random_odd_uint::(&mut OsRng, 128, U128::BITS); let start_bn = to_gmp(&start); let p_bn = start_bn.next_prime(); let p = from_gmp(&p_bn); @@ -451,7 +456,7 @@ mod tests_gmp { // Generate random numbers, check if our test agrees with GMP for _ in 0..100 { - let p = random_odd_uint::(&mut OsRng, 128); + let p = random_odd_uint::(&mut OsRng, 128, U128::BITS); let actual = is_prime(p.as_ref()); let p_bn = to_gmp(&p); let expected = gmp_is_prime(&p_bn); diff --git a/src/traits.rs b/src/traits.rs index 458d85c..bed5970 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -15,7 +15,11 @@ pub trait RandomPrimeWithRng { /// Panics if `bit_length` is less than 2, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. - fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; + fn generate_prime_with_rng( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, + ) -> Self; /// Returns a random safe prime (that is, such that `(n - 1) / 2` is also prime) /// of size `bit_length` using the provided RNG. @@ -24,7 +28,11 @@ pub trait RandomPrimeWithRng { /// Panics if `bit_length` is less than 3, or greater than the bit size of the target `Uint`. /// /// See [`is_prime_with_rng`] for details about the performed checks. - fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; + fn generate_safe_prime_with_rng( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, + ) -> Self; /// Checks probabilistically if the given number is prime using the provided RNG. /// @@ -38,11 +46,19 @@ pub trait RandomPrimeWithRng { } impl RandomPrimeWithRng for Uint { - fn generate_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { - generate_prime_with_rng(rng, bit_length) + fn generate_prime_with_rng( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, + ) -> Self { + generate_prime_with_rng(rng, bit_length, bits_precision) } - fn generate_safe_prime_with_rng(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { - generate_safe_prime_with_rng(rng, bit_length) + fn generate_safe_prime_with_rng( + rng: &mut impl CryptoRngCore, + bit_length: u32, + bits_precision: u32, + ) -> Self { + generate_safe_prime_with_rng(rng, bit_length, bits_precision) } fn is_prime_with_rng(&self, rng: &mut impl CryptoRngCore) -> bool { is_prime_with_rng(rng, self) @@ -67,9 +83,10 @@ mod tests { assert!(!U64::from(13u32).is_safe_prime_with_rng(&mut OsRng)); assert!(U64::from(11u32).is_safe_prime_with_rng(&mut OsRng)); - assert!(U64::generate_prime_with_rng(&mut OsRng, 10).is_prime_with_rng(&mut OsRng)); assert!( - U64::generate_safe_prime_with_rng(&mut OsRng, 10).is_safe_prime_with_rng(&mut OsRng) + U64::generate_prime_with_rng(&mut OsRng, 10, U64::BITS).is_prime_with_rng(&mut OsRng) ); + assert!(U64::generate_safe_prime_with_rng(&mut OsRng, 10, U64::BITS) + .is_safe_prime_with_rng(&mut OsRng)); } } diff --git a/src/uint_traits.rs b/src/uint_traits.rs index 015ba4e..0874675 100644 --- a/src/uint_traits.rs +++ b/src/uint_traits.rs @@ -14,7 +14,7 @@ pub trait UintLike: Integer + RandomMod { fn overflowing_shr_vartime(&self, shift: u32) -> CtOption; fn wrapping_shl_vartime(&self, shift: u32) -> Self; fn wrapping_shr_vartime(&self, shift: u32) -> Self; - fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self; + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32, bits_precision: u32) -> Self; } // Uint impls @@ -60,7 +60,13 @@ impl UintLike for Uint { Self::wrapping_shr_vartime(self, shift) } - fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { + /// TODO: bits_precision is required because BoxedUint::random requires bits_precision + /// we can require the user to input the correct precision: + /// in this case the documentation will need to communicate to the users about this + /// requirement, and we could put an assert statement to check + /// + /// We can also accept whatever value and just ignore it. + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32, _bits_precision: u32) -> Self { if bit_length > Self::BITS { panic!("The requested bit length ({bit_length}) is larger than the chosen Uint size"); } @@ -112,9 +118,11 @@ impl UintLike for BoxedUint { Self::wrapping_shr_vartime(self, shift) } - fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32) -> Self { - let random = Self::random(rng, bit_length); - let bits_precision = random.bits_precision(); + fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32, bits_precision: u32) -> Self { + if bit_length > bits_precision { + panic!("The requested bit length ({bit_length}) is larger than the chosen Uint size"); + } + let random = Self::random(rng, bits_precision); random >> (bits_precision - bit_length) } } From 32ff478a965b430b26b5b65f83f871dedcdce7d0 Mon Sep 17 00:00:00 2001 From: "Ganyu (Bruce) Xu" Date: Mon, 25 Dec 2023 18:58:35 +0800 Subject: [PATCH 3/5] Start with failing tests --- src/presets.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/presets.rs b/src/presets.rs index ebde8df..c1ee630 100644 --- a/src/presets.rs +++ b/src/presets.rs @@ -180,7 +180,7 @@ fn _is_prime_with_rng(rng: &mut impl CryptoRngCore, num: &Odd) - #[cfg(test)] mod tests { - use crypto_bigint::{CheckedAdd, Uint, Word, U128, U64}; + use crypto_bigint::{BoxedUint, CheckedAdd, Uint, Word, U128, U64}; use num_prime::nt_funcs::is_prime64; use rand_core::OsRng; @@ -267,6 +267,15 @@ mod tests { } } + #[test] + fn prime_generation_boxed() { + for bit_length in (28..=128).step_by(10) { + let p: BoxedUint = generate_prime(bit_length, 128); + assert!(p.bits_vartime() == bit_length); + assert!(is_prime(&p)); + } + } + #[test] fn safe_prime_generation() { for bit_length in (28..=128).step_by(10) { @@ -276,6 +285,15 @@ mod tests { } } + #[test] + fn safe_prime_generation_boxed() { + for bit_length in (28..=128).step_by(10) { + let p: BoxedUint = generate_safe_prime(bit_length, 128); + assert!(p.bits_vartime() == bit_length); + assert!(is_safe_prime(&p)); + } + } + #[test] fn corner_cases_is_prime() { for num in 0u64..30 { From 89760f9d73d94c4ae0c10b3402ff2fd876dabe80 Mon Sep 17 00:00:00 2001 From: "Ganyu (Bruce) Xu" Date: Mon, 25 Dec 2023 19:06:11 +0800 Subject: [PATCH 4/5] call BoxedUint::one_with_precision to ensure with the same precision --- src/uint_traits.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/uint_traits.rs b/src/uint_traits.rs index 0874675..39ffe1d 100644 --- a/src/uint_traits.rs +++ b/src/uint_traits.rs @@ -78,9 +78,9 @@ impl UintLike for Uint { impl UintLike for BoxedUint { fn set_bit_vartime(&mut self, index: u32, value: bool) { if value { - *self |= Self::one() << index + *self |= Self::one_with_precision(self.bits_precision()) << index } else { - *self &= (Self::one() << index).not() + *self &= (Self::one_with_precision(self.bits_precision()) << index).not() } } From b59be9c919cb00dfc70f0046eb46ea228562ee99 Mon Sep 17 00:00:00 2001 From: "Ganyu (Bruce) Xu" Date: Mon, 25 Dec 2023 20:22:59 +0800 Subject: [PATCH 5/5] overflowing_shl, widen, ones and zeros with precision --- src/hazmat/lucas.rs | 10 +++++++--- src/hazmat/miller_rabin.rs | 9 +++++---- src/uint_traits.rs | 35 +++++++++++++++++++++++++++++++---- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/hazmat/lucas.rs b/src/hazmat/lucas.rs index 9430ec4..ac72fc0 100644 --- a/src/hazmat/lucas.rs +++ b/src/hazmat/lucas.rs @@ -352,7 +352,10 @@ pub fn lucas_test( let q = if q_is_one { one.clone() } else { - let abs_q = ::Monty::new(T::from(abs_q), params.clone()); + let abs_q = ::Monty::new( + T::from(abs_q).widen(candidate.bits_precision()), + params.clone(), + ); if q_is_negative { -abs_q } else { @@ -365,7 +368,7 @@ pub fn lucas_test( let p = if p_is_one { one.clone() } else { - ::Monty::new(T::from(p), params.clone()) + ::Monty::new(T::from(p).widen(candidate.bits_precision()), params.clone()) }; // Compute d-th element of Lucas sequence (U_d(P, Q), V_d(P, Q)), where: @@ -388,7 +391,8 @@ pub fn lucas_test( let mut qk = one.clone(); // keeps Q^k // D in Montgomery representation - note that it can be negative. - let abs_d = ::Monty::new(T::from(abs_d), params); + let abs_d = + ::Monty::new(T::from(abs_d).widen(candidate.bits_precision()), params); let d_m = if d_is_negative { -abs_d } else { abs_d }; for i in (0..d.bits_vartime()).rev() { diff --git a/src/hazmat/miller_rabin.rs b/src/hazmat/miller_rabin.rs index aa4ff40..4dffa94 100644 --- a/src/hazmat/miller_rabin.rs +++ b/src/hazmat/miller_rabin.rs @@ -33,10 +33,11 @@ impl MillerRabin { let minus_one = -one.clone(); // Find `s` and odd `d` such that `candidate - 1 == 2^s * d`. - let (s, d) = if candidate.as_ref() == &T::one() { - (0, T::one()) + let (s, d) = if candidate.as_ref() == &T::one_with_precision(candidate.bits_precision()) { + (0, T::one_with_precision(candidate.bits_precision())) } else { - let candidate_minus_one = candidate.wrapping_sub(&T::one()); + let candidate_minus_one = + candidate.wrapping_sub(&T::one_with_precision(candidate.bits_precision())); let s = candidate_minus_one.trailing_zeros_vartime(); // Will not overflow because `candidate` is odd and greater than 1. let d = candidate_minus_one @@ -86,7 +87,7 @@ impl MillerRabin { /// Perform a Miller-Rabin check with base 2. pub fn test_base_two(&self) -> Primality { - self.test(&T::from(2u32)) + self.test(&T::from(2u32).widen(self.candidate.bits_precision())) } /// Perform a Miller-Rabin check with a random base (in the range `[3, candidate-2]`) diff --git a/src/uint_traits.rs b/src/uint_traits.rs index 39ffe1d..093bbe6 100644 --- a/src/uint_traits.rs +++ b/src/uint_traits.rs @@ -15,6 +15,9 @@ pub trait UintLike: Integer + RandomMod { fn wrapping_shl_vartime(&self, shift: u32) -> Self; fn wrapping_shr_vartime(&self, shift: u32) -> Self; fn random_bits(rng: &mut impl CryptoRngCore, bit_length: u32, bits_precision: u32) -> Self; + fn one_with_precision(bits_precision: u32) -> Self; + fn zero_with_precision(bits_precision: u32) -> Self; + fn widen(&self, bits_precision: u32) -> Self; } // Uint impls @@ -73,6 +76,18 @@ impl UintLike for Uint { let random = Self::random(rng); random >> (Self::BITS - bit_length) } + + fn one_with_precision(_bits_precision: u32) -> Self { + Self::ONE + } + + fn zero_with_precision(_bits_precision: u32) -> Self { + Self::ZERO + } + + fn widen(&self, _bits_precision: u32) -> Self { + *self + } } impl UintLike for BoxedUint { @@ -101,13 +116,13 @@ impl UintLike for BoxedUint { } fn overflowing_shl_vartime(&self, shift: u32) -> CtOption { - let (res, is_some) = Self::overflowing_shl(self, shift); - CtOption::new(res, is_some) + let (res, overflow) = Self::overflowing_shl(self, shift); + CtOption::new(res, !overflow) } fn overflowing_shr_vartime(&self, shift: u32) -> CtOption { - let (res, is_some) = Self::overflowing_shr(self, shift); - CtOption::new(res, is_some) + let (res, overflow) = Self::overflowing_shr(self, shift); + CtOption::new(res, !overflow) } fn wrapping_shl_vartime(&self, shift: u32) -> Self { @@ -125,4 +140,16 @@ impl UintLike for BoxedUint { let random = Self::random(rng, bits_precision); random >> (bits_precision - bit_length) } + + fn zero_with_precision(bits_precision: u32) -> Self { + Self::zero_with_precision(bits_precision) + } + + fn one_with_precision(bits_precision: u32) -> Self { + Self::one_with_precision(bits_precision) + } + + fn widen(&self, bits_precision: u32) -> Self { + self.widen(bits_precision) + } }