Skip to content

Commit

Permalink
Let users pick the thread count
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdplm committed Nov 5, 2024
1 parent e4b1adb commit 7491a06
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/presets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ pub fn generate_safe_prime_with_rng<T: Integer + RandomBits + RandomMod>(
///
/// Panics if the platform is unable to spawn threads.
#[cfg(feature = "rayon")]
pub fn par_generate_prime_with_rng<T>(rng: &mut (impl CryptoRngCore + Send + Sync + Clone), bit_length: u32) -> T
pub fn par_generate_prime_with_rng<T>(
rng: &mut (impl CryptoRngCore + Send + Sync + Clone),
bit_length: u32,
threadcount: usize,
) -> T
where
T: Integer + RandomBits + RandomMod,
{
Expand All @@ -110,8 +114,6 @@ where
}
let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero");

// TODO(dp): decide how to set the threadcount.
let threadcount = core::cmp::max(2, num_cpus::get() / 2);
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(threadcount)
.build()
Expand All @@ -129,7 +131,7 @@ where
Some(prime) => prime,
None => {
drop(threadpool);
par_generate_prime_with_rng(rng, bit_length.get())
par_generate_prime_with_rng(rng, bit_length.get(), threadcount)
}
}
}
Expand All @@ -144,7 +146,11 @@ where
///
/// See [`is_prime_with_rng`] for details about the performed checks.
#[cfg(feature = "rayon")]
pub fn par_generate_safe_prime_with_rng<T>(rng: &mut (impl CryptoRngCore + Send + Sync + Clone), bit_length: u32) -> T
pub fn par_generate_safe_prime_with_rng<T>(
rng: &mut (impl CryptoRngCore + Send + Sync + Clone),
bit_length: u32,
threadcount: usize,
) -> T
where
T: Integer + RandomBits + RandomMod,
{
Expand All @@ -153,8 +159,6 @@ where
}
let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero");

// TODO(dp): decide how to set the threadcount.
let threadcount = core::cmp::max(2, num_cpus::get() / 2);
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(threadcount)
.build()
Expand All @@ -172,7 +176,7 @@ where
Some(prime) => prime,
None => {
drop(threadpool);
par_generate_safe_prime_with_rng(rng, bit_length.get())
par_generate_safe_prime_with_rng(rng, bit_length.get(), threadcount)
}
}
}
Expand Down

0 comments on commit 7491a06

Please sign in to comment.