From 7491a06939d96d1e77f9a352bd51fe86f19183b8 Mon Sep 17 00:00:00 2001 From: David Palm Date: Tue, 5 Nov 2024 14:12:50 +0100 Subject: [PATCH] Let users pick the thread count --- src/presets.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/presets.rs b/src/presets.rs index 118c63c..4028692 100644 --- a/src/presets.rs +++ b/src/presets.rs @@ -101,7 +101,11 @@ pub fn generate_safe_prime_with_rng( /// /// Panics if the platform is unable to spawn threads. #[cfg(feature = "rayon")] -pub fn par_generate_prime_with_rng(rng: &mut (impl CryptoRngCore + Send + Sync + Clone), bit_length: u32) -> T +pub fn par_generate_prime_with_rng( + rng: &mut (impl CryptoRngCore + Send + Sync + Clone), + bit_length: u32, + threadcount: usize, +) -> T where T: Integer + RandomBits + RandomMod, { @@ -110,8 +114,6 @@ where } let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero"); - // TODO(dp): decide how to set the threadcount. - let threadcount = core::cmp::max(2, num_cpus::get() / 2); let threadpool = rayon::ThreadPoolBuilder::new() .num_threads(threadcount) .build() @@ -129,7 +131,7 @@ where Some(prime) => prime, None => { drop(threadpool); - par_generate_prime_with_rng(rng, bit_length.get()) + par_generate_prime_with_rng(rng, bit_length.get(), threadcount) } } } @@ -144,7 +146,11 @@ where /// /// See [`is_prime_with_rng`] for details about the performed checks. #[cfg(feature = "rayon")] -pub fn par_generate_safe_prime_with_rng(rng: &mut (impl CryptoRngCore + Send + Sync + Clone), bit_length: u32) -> T +pub fn par_generate_safe_prime_with_rng( + rng: &mut (impl CryptoRngCore + Send + Sync + Clone), + bit_length: u32, + threadcount: usize, +) -> T where T: Integer + RandomBits + RandomMod, { @@ -153,8 +159,6 @@ where } let bit_length = NonZeroU32::new(bit_length).expect("`bit_length` should be non-zero"); - // TODO(dp): decide how to set the threadcount. - let threadcount = core::cmp::max(2, num_cpus::get() / 2); let threadpool = rayon::ThreadPoolBuilder::new() .num_threads(threadcount) .build() @@ -172,7 +176,7 @@ where Some(prime) => prime, None => { drop(threadpool); - par_generate_safe_prime_with_rng(rng, bit_length.get()) + par_generate_safe_prime_with_rng(rng, bit_length.get(), threadcount) } } }