Skip to content

Commit

Permalink
Prefer core::num::NonZero over crypto_bigint::NonZero where possible (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdplm authored Nov 7, 2024
1 parent 7b88e3e commit 056781d
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 42 deletions.
18 changes: 9 additions & 9 deletions benches/bench.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::num::NonZeroU32;
use core::num::NonZero;

use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::{nlimbs, BoxedUint, Integer, Odd, RandomBits, Uint, U1024, U128, U256};
Expand All @@ -24,12 +24,12 @@ fn make_rng() -> ChaCha8Rng {
}

fn random_odd_uint<T: RandomBits + Integer>(rng: &mut impl CryptoRngCore, bit_length: u32) -> Odd<T> {
random_odd_integer::<T>(rng, NonZeroU32::new(bit_length).unwrap())
random_odd_integer::<T>(rng, NonZero::new(bit_length).unwrap())
}

fn make_sieve<const L: usize>(rng: &mut impl CryptoRngCore) -> Sieve<Uint<L>> {
let start = random_odd_uint::<Uint<L>>(rng, Uint::<L>::BITS);
Sieve::new(start.get(), NonZeroU32::new(Uint::<L>::BITS).unwrap(), false)
Sieve::new(start.get(), NonZero::new(Uint::<L>::BITS).unwrap(), false)
}

fn make_presieved_num<const L: usize>(rng: &mut impl CryptoRngCore) -> Odd<Uint<L>> {
Expand All @@ -47,7 +47,7 @@ fn bench_sieve(c: &mut Criterion) {
group.bench_function("(U128) creation", |b| {
b.iter_batched(
|| random_odd_uint::<U128>(&mut OsRng, 128),
|start| Sieve::new(start.get(), NonZeroU32::new(128).unwrap(), false),
|start| Sieve::new(start.get(), NonZero::new(128).unwrap(), false),
BatchSize::SmallInput,
)
});
Expand All @@ -68,7 +68,7 @@ fn bench_sieve(c: &mut Criterion) {
group.bench_function("(U1024) creation", |b| {
b.iter_batched(
|| random_odd_uint::<U1024>(&mut OsRng, 1024),
|start| Sieve::new(start.get(), NonZeroU32::new(1024).unwrap(), false),
|start| Sieve::new(start.get(), NonZero::new(1024).unwrap(), false),
BatchSize::SmallInput,
)
});
Expand Down Expand Up @@ -378,8 +378,8 @@ fn bench_glass_pumpkin(c: &mut Criterion) {
// Mimics the sequence of checks `glass-pumpkin` does to find a prime.
fn prime_like_gp(bit_length: u32, rng: &mut impl CryptoRngCore) -> BoxedUint {
loop {
let start = random_odd_integer::<BoxedUint>(rng, NonZeroU32::new(bit_length).unwrap()).get();
let sieve = Sieve::new(start, NonZeroU32::new(bit_length).unwrap(), false);
let start = random_odd_integer::<BoxedUint>(rng, NonZero::new(bit_length).unwrap()).get();
let sieve = Sieve::new(start, NonZero::new(bit_length).unwrap(), false);
for num in sieve {
let odd_num = Odd::new(num.clone()).unwrap();

Expand All @@ -402,8 +402,8 @@ fn bench_glass_pumpkin(c: &mut Criterion) {
// Mimics the sequence of checks `glass-pumpkin` does to find a safe prime.
fn safe_prime_like_gp(bit_length: u32, rng: &mut impl CryptoRngCore) -> BoxedUint {
loop {
let start = random_odd_integer::<BoxedUint>(rng, NonZeroU32::new(bit_length).unwrap()).get();
let sieve = Sieve::new(start, NonZeroU32::new(bit_length).unwrap(), true);
let start = random_odd_integer::<BoxedUint>(rng, NonZero::new(bit_length).unwrap()).get();
let sieve = Sieve::new(start, NonZero::new(bit_length).unwrap(), true);
for num in sieve {
let odd_num = Odd::new(num.clone()).unwrap();

Expand Down
8 changes: 5 additions & 3 deletions src/hazmat/gcd.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crypto_bigint::{Integer, Limb, NonZero, Word};
use core::num::NonZero;
use crypto_bigint::{Integer, Limb, NonZero as CTNonZero, Word};

/// Calculates the greatest common divisor of `n` and `m`.
/// By definition, `gcd(0, m) == m`.
Expand All @@ -14,7 +15,7 @@ pub(crate) fn gcd_vartime<T: Integer>(n: &T, m: NonZero<Word>) -> Word {
// Normalize input: the resulting (a, b) are both small, a >= b, and b != 0.
let (a, b): (Word, Word) = if n.bits() > Word::BITS {
// `m` is non-zero, so we can unwrap.
let r = n.rem_limb(NonZero::new(Limb::from(m)).expect("divisor should be non-zero here"));
let r = n.rem_limb(CTNonZero::new(Limb::from(m)).expect("divisor should be non-zero here"));
(m, r.0)
} else {
// In this branch `n` is `Word::BITS` bits or shorter,
Expand Down Expand Up @@ -74,7 +75,8 @@ fn binary_gcd(mut n: Word, mut m: Word) -> Word {

#[cfg(test)]
mod tests {
use crypto_bigint::{NonZero, Word, U128};
use core::num::NonZero;
use crypto_bigint::{Word, U128};
use num_bigint::BigUint;
use num_integer::Integer;
use proptest::prelude::*;
Expand Down
4 changes: 2 additions & 2 deletions src/hazmat/jacobi.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Jacobi symbol calculation.
use crypto_bigint::{Integer, Limb, NonZero, Odd, Word};
use crypto_bigint::{Integer, Limb, NonZero as CTNonZero, Odd, Word};

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum JacobiSymbol {
Expand Down Expand Up @@ -99,7 +99,7 @@ pub(crate) fn jacobi_symbol_vartime<T: Integer>(abs_a: Word, a_is_negative: bool
let (result, a_long, p) = swap_long(result, a, p_long);
// Can unwrap here, since `p` is swapped with `a`,
// and `a` would be odd after `reduce_numerator()`.
let a = a_long.rem_limb(NonZero::new(Limb::from(p)).expect("divisor should be non-zero here"));
let a = a_long.rem_limb(CTNonZero::new(Limb::from(p)).expect("divisor should be non-zero here"));
(result, a.0, p)
};

Expand Down
3 changes: 2 additions & 1 deletion src/hazmat/lucas.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Lucas primality test.
use crypto_bigint::{Integer, Limb, Monty, NonZero, Odd, Square, Word};
use core::num::NonZero;
use crypto_bigint::{Integer, Limb, Monty, Odd, Square, Word};

use super::{
gcd::gcd_vartime,
Expand Down
11 changes: 5 additions & 6 deletions src/hazmat/miller_rabin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Miller-Rabin primality test.
use crypto_bigint::{Integer, Limb, Monty, NonZero, Odd, PowBoundedExp, RandomMod, Square};
use crypto_bigint::{Integer, Limb, Monty, NonZero as CTNonZero, Odd, PowBoundedExp, RandomMod, Square};
use rand_core::CryptoRngCore;

use super::Primality;
Expand Down Expand Up @@ -115,7 +115,7 @@ impl<T: Integer + RandomMod> MillerRabin<T> {

let range = self.candidate.wrapping_sub(&T::from(4u32));
// Can unwrap here since `candidate` is odd, and `candidate >= 4` (as checked above)
let range_nonzero = NonZero::new(range).expect("the range should be non-zero by construction");
let range_nonzero = CTNonZero::new(range).expect("the range should be non-zero by construction");
// This should not overflow as long as `random_mod()` behaves according to the contract
// (that is, returns a number within the given range).
let random = T::random_mod(rng, &range_nonzero)
Expand All @@ -136,9 +136,8 @@ impl<T: Integer + RandomMod> MillerRabin<T> {

#[cfg(test)]
mod tests {

use alloc::format;
use core::num::NonZeroU32;
use core::num::NonZero;

use crypto_bigint::{Integer, Odd, RandomMod, Uint, U1024, U128, U1536, U64};
use rand_chacha::ChaCha8Rng;
Expand Down Expand Up @@ -197,8 +196,8 @@ mod tests {
#[test]
fn trivial() {
let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
let start = random_odd_integer::<U1024>(&mut rng, NonZeroU32::new(1024).unwrap());
for num in Sieve::new(start.get(), NonZeroU32::new(1024).unwrap(), false).take(10) {
let start = random_odd_integer::<U1024>(&mut rng, NonZero::new(1024).unwrap());
for num in Sieve::new(start.get(), NonZero::new(1024).unwrap(), false).take(10) {
let mr = MillerRabin::new(Odd::new(num).unwrap());

// Trivial tests, must always be true.
Expand Down
26 changes: 13 additions & 13 deletions src/hazmat/sieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ mod tests {

use alloc::format;
use alloc::vec::Vec;
use core::num::NonZeroU32;
use core::num::NonZero;

use crypto_bigint::U64;
use num_prime::nt_funcs::factorize64;
Expand All @@ -264,8 +264,8 @@ mod tests {
let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1];

let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
let start = random_odd_integer::<U64>(&mut rng, NonZeroU32::new(32).unwrap()).get();
for num in Sieve::new(start, NonZeroU32::new(32).unwrap(), false).take(100) {
let start = random_odd_integer::<U64>(&mut rng, NonZero::new(32).unwrap()).get();
for num in Sieve::new(start, NonZero::new(32).unwrap(), false).take(100) {
let num_u64 = u64::from(num);
assert!(num_u64.leading_zeros() == 32);

Expand All @@ -280,9 +280,9 @@ mod tests {
let max_prime = SMALL_PRIMES[SMALL_PRIMES.len() - 1];

let mut rng = ChaCha8Rng::from_seed(*b"01234567890123456789012345678901");
let start = random_odd_integer::<crypto_bigint::BoxedUint>(&mut rng, NonZeroU32::new(32).unwrap()).get();
let start = random_odd_integer::<crypto_bigint::BoxedUint>(&mut rng, NonZero::new(32).unwrap()).get();

for num in Sieve::new(start, NonZeroU32::new(32).unwrap(), false).take(100) {
for num in Sieve::new(start, NonZero::new(32).unwrap(), false).take(100) {
// For 32-bit targets
#[allow(clippy::useless_conversion)]
let num_u64: u64 = num.as_words()[0].into();
Expand All @@ -296,7 +296,7 @@ mod tests {
}

fn check_sieve(start: u32, bit_length: u32, safe_prime: bool, reference: &[u32]) {
let test = Sieve::new(U64::from(start), NonZeroU32::new(bit_length).unwrap(), safe_prime).collect::<Vec<_>>();
let test = Sieve::new(U64::from(start), NonZero::new(bit_length).unwrap(), safe_prime).collect::<Vec<_>>();
assert_eq!(test.len(), reference.len());
for (x, y) in test.iter().zip(reference.iter()) {
assert_eq!(x, &U64::from(*y));
Expand Down Expand Up @@ -351,42 +351,42 @@ mod tests {
#[test]
#[should_panic(expected = "The requested bit length (65) is larger than the precision of `start`")]
fn sieve_too_many_bits() {
let _sieve = Sieve::new(U64::ONE, NonZeroU32::new(65).unwrap(), false);
let _sieve = Sieve::new(U64::ONE, NonZero::new(65).unwrap(), false);
}

#[test]
fn random_below_max_length() {
for _ in 0..10 {
let r = random_odd_integer::<U64>(&mut OsRng, NonZeroU32::new(50).unwrap()).get();
let r = random_odd_integer::<U64>(&mut OsRng, NonZero::new(50).unwrap()).get();
assert_eq!(r.bits(), 50);
}
}

#[test]
#[should_panic(expected = "try_random_bits() failed: BitLengthTooLarge { bit_length: 65, bits_precision: 64 }")]
fn random_odd_uint_too_many_bits() {
let _p = random_odd_integer::<U64>(&mut OsRng, NonZeroU32::new(65).unwrap());
let _p = random_odd_integer::<U64>(&mut OsRng, NonZero::new(65).unwrap());
}

#[test]
fn sieve_derived_traits() {
let s = Sieve::new(U64::ONE, NonZeroU32::new(10).unwrap(), false);
let s = Sieve::new(U64::ONE, NonZero::new(10).unwrap(), false);
// Debug
assert!(format!("{s:?}").starts_with("Sieve"));
// Clone
assert_eq!(s.clone(), s);

// PartialEq
let s2 = Sieve::new(U64::ONE, NonZeroU32::new(10).unwrap(), false);
let s2 = Sieve::new(U64::ONE, NonZero::new(10).unwrap(), false);
assert_eq!(s, s2);
let s3 = Sieve::new(U64::ONE, NonZeroU32::new(12).unwrap(), false);
let s3 = Sieve::new(U64::ONE, NonZero::new(12).unwrap(), false);
assert_ne!(s, s3);
}

#[test]
fn sieve_with_max_start() {
let start = U64::MAX;
let mut sieve = Sieve::new(start, NonZeroU32::new(U64::BITS).unwrap(), false);
let mut sieve = Sieve::new(start, NonZero::new(U64::BITS).unwrap(), false);
assert!(sieve.next().is_none());
}
}
16 changes: 8 additions & 8 deletions src/presets.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::num::NonZeroU32;
use core::num::NonZero;

use crypto_bigint::{Integer, Limb, Odd, RandomBits, RandomMod};
use rand_core::CryptoRngCore;
Expand Down Expand Up @@ -54,7 +54,7 @@ pub fn generate_prime_with_rng<T: Integer + RandomBits + RandomMod>(
if bit_length < 2 {
panic!("`bit_length` must be 2 or greater.");
}
let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero");
let bit_length = NonZero::new(bit_length).expect("`bit_length` should be non-zero");
// Empirically, this loop is traversed 1 time.
loop {
let start = random_odd_integer::<T>(rng, bit_length);
Expand All @@ -78,7 +78,7 @@ pub fn generate_safe_prime_with_rng<T: Integer + RandomBits + RandomMod>(
if bit_length < 3 {
panic!("`bit_length` must be 3 or greater.");
}
let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero");
let bit_length = NonZero::new(bit_length).expect("`bit_length` should be non-zero");
loop {
let start = random_odd_integer::<T>(rng, bit_length);
let sieve = Sieve::new(start.get(), bit_length, true);
Expand Down Expand Up @@ -371,7 +371,7 @@ mod tests {
#[cfg(feature = "tests-openssl")]
mod tests_openssl {
use alloc::format;
use core::num::NonZeroU32;
use core::num::NonZero;

use crypto_bigint::U128;
use openssl::bn::{BigNum, BigNumContext};
Expand Down Expand Up @@ -413,7 +413,7 @@ mod tests_openssl {

// Generate random numbers, check if our test agrees with OpenSSL
for _ in 0..100 {
let p = random_odd_integer::<U128>(&mut OsRng, NonZeroU32::new(128).unwrap());
let p = random_odd_integer::<U128>(&mut OsRng, NonZero::new(128).unwrap());
let actual = is_prime(p.as_ref());
let p_bn = to_openssl(&p);
let expected = openssl_is_prime(&p_bn, &mut ctx);
Expand All @@ -428,7 +428,7 @@ mod tests_openssl {
#[cfg(test)]
#[cfg(feature = "tests-gmp")]
mod tests_gmp {
use core::num::NonZeroU32;
use core::num::NonZero;

use crypto_bigint::U128;
use rand_core::OsRng;
Expand Down Expand Up @@ -463,7 +463,7 @@ mod tests_gmp {

// Generate primes with GMP, check them
for _ in 0..100 {
let start = random_odd_integer::<U128>(&mut OsRng, NonZeroU32::new(128).unwrap());
let start = random_odd_integer::<U128>(&mut OsRng, NonZero::new(128).unwrap());
let start_bn = to_gmp(&start);
let p_bn = start_bn.next_prime();
let p = from_gmp(&p_bn);
Expand All @@ -472,7 +472,7 @@ mod tests_gmp {

// Generate random numbers, check if our test agrees with GMP
for _ in 0..100 {
let p = random_odd_integer::<U128>(&mut OsRng, NonZeroU32::new(128).unwrap());
let p = random_odd_integer::<U128>(&mut OsRng, NonZero::new(128).unwrap());
let actual = is_prime(p.as_ref());
let p_bn = to_gmp(&p);
let expected = gmp_is_prime(&p_bn);
Expand Down

0 comments on commit 056781d

Please sign in to comment.