diff --git a/.gitmodules b/.gitmodules index f1b208d..a939eb3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "mbedtls"] path = mbedtls - url = https://github.com/Mbed-TLS/mbedtls + url = https://github.com/espressif/mbedtls diff --git a/Cargo.toml b/Cargo.toml index 16224b2..68fab06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,9 @@ static_cell = { version = "=1.2", features = ["nightly"] } esp-mbedtls = { path = "./esp-mbedtls" } -[target.xtensa-esp32s3-none-elf.dependencies] +[[example]] +name = "crypto_self_test" +required-features = ["esp-wifi/wifi-logs"] [[example]] name = "async_client" diff --git a/esp-mbedtls-sys/headers/esp32c3/config.h b/esp-mbedtls-sys/headers/esp32c3/config.h index 7eb82a8..8b1b00e 100644 --- a/esp-mbedtls-sys/headers/esp32c3/config.h +++ b/esp-mbedtls-sys/headers/esp32c3/config.h @@ -3609,6 +3609,8 @@ /* MPI / BIGNUM options */ //#define MBEDTLS_MPI_WINDOW_SIZE 2 /**< Maximum window size used. */ //#define MBEDTLS_MPI_MAX_SIZE 1024 /**< Maximum number of bytes for usable MPIs. */ +#define MBEDTLS_MPI_EXP_MOD_ALT +// #define MBEDTLS_MPI_MUL_MPI_ALT /* CTR_DRBG options */ //#define MBEDTLS_CTR_DRBG_ENTROPY_LEN 48 /**< Amount of entropy used per seed by default (48 with SHA-512, 32 with SHA-256) */ diff --git a/esp-mbedtls-sys/headers/esp32s3/config.h b/esp-mbedtls-sys/headers/esp32s3/config.h index 1f366db..8398e71 100644 --- a/esp-mbedtls-sys/headers/esp32s3/config.h +++ b/esp-mbedtls-sys/headers/esp32s3/config.h @@ -3609,6 +3609,9 @@ /* MPI / BIGNUM options */ //#define MBEDTLS_MPI_WINDOW_SIZE 2 /**< Maximum window size used. */ //#define MBEDTLS_MPI_MAX_SIZE 1024 /**< Maximum number of bytes for usable MPIs. */ +// #define MBEDTLS_BIGNUM_ALT +#define MBEDTLS_MPI_EXP_MOD_ALT +#define MBEDTLS_MPI_MUL_MPI_ALT /* CTR_DRBG options */ //#define MBEDTLS_CTR_DRBG_ENTROPY_LEN 48 /**< Amount of entropy used per seed by default (48 with SHA-512, 32 with SHA-256) */ diff --git a/esp-mbedtls-sys/src/include/esp32.rs b/esp-mbedtls-sys/src/include/esp32.rs index 34a5e23..4635225 100644 --- a/esp-mbedtls-sys/src/include/esp32.rs +++ b/esp-mbedtls-sys/src/include/esp32.rs @@ -4808,6 +4808,7 @@ extern "C" { /// buffer of length \p blen Bytes. It may be \c NULL if /// \p blen is zero. /// \param blen The length of \p buf in Bytes. + /// \param md_alg The hash algorithm used to hash the original data. /// \param f_rng_blind The RNG function used for blinding. This must not be /// \c NULL. /// \param p_rng_blind The RNG context to be passed to \p f_rng. This may be diff --git a/esp-mbedtls-sys/src/include/esp32c3.rs b/esp-mbedtls-sys/src/include/esp32c3.rs index 02e7202..c5d4029 100644 --- a/esp-mbedtls-sys/src/include/esp32c3.rs +++ b/esp-mbedtls-sys/src/include/esp32c3.rs @@ -4810,6 +4810,7 @@ extern "C" { /// buffer of length \p blen Bytes. It may be \c NULL if /// \p blen is zero. /// \param blen The length of \p buf in Bytes. + /// \param md_alg The hash algorithm used to hash the original data. /// \param f_rng_blind The RNG function used for blinding. This must not be /// \c NULL. /// \param p_rng_blind The RNG context to be passed to \p f_rng. This may be diff --git a/esp-mbedtls-sys/src/include/esp32s2.rs b/esp-mbedtls-sys/src/include/esp32s2.rs index 34a5e23..4635225 100644 --- a/esp-mbedtls-sys/src/include/esp32s2.rs +++ b/esp-mbedtls-sys/src/include/esp32s2.rs @@ -4808,6 +4808,7 @@ extern "C" { /// buffer of length \p blen Bytes. It may be \c NULL if /// \p blen is zero. /// \param blen The length of \p buf in Bytes. + /// \param md_alg The hash algorithm used to hash the original data. /// \param f_rng_blind The RNG function used for blinding. This must not be /// \c NULL. /// \param p_rng_blind The RNG context to be passed to \p f_rng. This may be diff --git a/esp-mbedtls-sys/src/include/esp32s3.rs b/esp-mbedtls-sys/src/include/esp32s3.rs index 34a5e23..4635225 100644 --- a/esp-mbedtls-sys/src/include/esp32s3.rs +++ b/esp-mbedtls-sys/src/include/esp32s3.rs @@ -4808,6 +4808,7 @@ extern "C" { /// buffer of length \p blen Bytes. It may be \c NULL if /// \p blen is zero. /// \param blen The length of \p buf in Bytes. + /// \param md_alg The hash algorithm used to hash the original data. /// \param f_rng_blind The RNG function used for blinding. This must not be /// \c NULL. /// \param p_rng_blind The RNG context to be passed to \p f_rng. This may be diff --git a/esp-mbedtls/Cargo.toml b/esp-mbedtls/Cargo.toml index 3221214..14365a2 100644 --- a/esp-mbedtls/Cargo.toml +++ b/esp-mbedtls/Cargo.toml @@ -8,10 +8,12 @@ esp-mbedtls-sys = { path = "../esp-mbedtls-sys" } log = "0.4.17" embedded-io = { version = "0.6.1" } embedded-io-async = { version = "0.6.0", optional = true } +crypto-bigint = { version = "0.5.3", default-features = false, features = ["extra-sizes"] } esp32-hal = { version = "0.16.0", optional = true } esp32c3-hal = { version = "0.13.0", optional = true } esp32s2-hal = { version = "0.13.0", optional = true } esp32s3-hal = { version = "0.13.0", optional = true } +cfg-if = "1.0.0" [features] async = ["dep:embedded-io-async"] diff --git a/esp-mbedtls/src/bignum.rs b/esp-mbedtls/src/bignum.rs new file mode 100644 index 0000000..cc18265 --- /dev/null +++ b/esp-mbedtls/src/bignum.rs @@ -0,0 +1,841 @@ +#![allow(non_snake_case)] + +use crate::hal::peripherals::RSA; +use crate::hal::prelude::nb; +use crate::hal::rsa::{ + operand_sizes, Rsa, RsaModularExponentiation, RsaModularMultiplication, RsaMultiplication, +}; + +use crypto_bigint::*; + +use esp_mbedtls_sys::bindings::*; +use esp_mbedtls_sys::c_types::*; + +macro_rules! error_checked { + ($block:expr) => {{ + let res = $block; + if res != 0 { + panic!("Non zero error {:?}", res); + } else { + // Do nothing for now + } + }}; +} + +#[cfg(feature = "esp32")] +const SOC_RSA_MAX_BIT_LEN: usize = 4096; +#[cfg(feature = "esp32c3")] +const SOC_RSA_MAX_BIT_LEN: usize = 3072; +#[cfg(feature = "esp32s2")] +const SOC_RSA_MAX_BIT_LEN: usize = 4096; +#[cfg(feature = "esp32s3")] +const SOC_RSA_MAX_BIT_LEN: usize = 4096; + +/// An error occurred while reading from or writing to a file. +const MBEDTLS_ERR_MPI_FILE_IO_ERROR: c_int = -0x0002; +/// Bad input parameters to function. +const MBEDTLS_ERR_MPI_BAD_INPUT_DATA: c_int = -0x0004; +/// There is an invalid character in the digit string. +const MBEDTLS_ERR_MPI_INVALID_CHARACTER: c_int = -0x0006; +/// The buffer is too small to write to. +const MBEDTLS_ERR_MPI_BUFFER_TOO_SMALL: c_int = -0x0008; +/// The input arguments are negative or result in illegal output. +const MBEDTLS_ERR_MPI_NEGATIVE_VALUE: c_int = -0x000A; +/// The input argument for division is zero, which is not allowed. +const MBEDTLS_ERR_MPI_DIVISION_BY_ZERO: c_int = -0x000C; +/// The input arguments are not acceptable. +const MBEDTLS_ERR_MPI_NOT_ACCEPTABLE: c_int = -0x000E; +/// Memory allocation failed. +const MBEDTLS_ERR_MPI_ALLOC_FAILED: c_int = -0x0010; + +fn compute_mprime(M: *const mbedtls_mpi) -> u32 { + let mut t: u64 = 1; + let mut two_2_i_minus_1: u64 = 2; // 2^(i-1) + let mut two_2_i: u64 = 4; // 2^i + let n = unsafe { (*M).private_p.read() } as u64; + + for i in 2..=32 { + if n * t % two_2_i >= two_2_i_minus_1 { + t += two_2_i_minus_1; + } + + two_2_i_minus_1 <<= 1; + two_2_i <<= 1; + } + + return (u32::MAX as u64 - t + 1) as u32; +} + +/// Calculate Rinv = RR^2 mod M, where: +/// +/// R = b^n where b = 2^32, n=num_words, +/// R = 2^N (where N=num_bits) +/// RR = R^2 = 2^(2*N) (where N=num_bits=num_words*32) +/// +/// This calculation is computationally expensive (mbedtls_mpi_mod_mpi) +/// so caller should cache the result where possible. +/// +/// DO NOT call this function while holding esp_mpi_enable_hardware_hw_op(). +unsafe fn calculate_rinv( + prec_RR: *mut mbedtls_mpi, + M: *const mbedtls_mpi, + num_words: usize, +) -> c_int { + let ret = 0; + let num_bits = num_words * 32; + let mut RR = mbedtls_mpi { + private_s: 0, + private_n: 0, + private_p: core::ptr::null_mut(), + }; + + mbedtls_mpi_init(&mut RR); + error_checked!(mbedtls_mpi_set_bit(&mut RR, num_bits * 2, 1)); + error_checked!(mbedtls_mpi_mod_mpi(prec_RR, &RR, M)); + + mbedtls_mpi_free(&mut RR); + + ret +} + +/// Z = X ^ Y mod M +#[no_mangle] +pub unsafe extern "C" fn mbedtls_mpi_exp_mod( + Z: *mut mbedtls_mpi, + X: *const mbedtls_mpi, + Y: *const mbedtls_mpi, + M: *const mbedtls_mpi, + prec_RR: *mut mbedtls_mpi, +) -> c_int { + let mut ret = 0; + let x_words = mpi_words(X); + let y_words = mpi_words(Y); + let m_words = mpi_words(M); + + // All numbers must be the lame length, so choose longest number as + // cardinal length of operation + let num_words = calculate_hw_words(core::cmp::max(m_words, core::cmp::max(x_words, y_words))); + + if num_words * 32 > SOC_RSA_MAX_BIT_LEN { + return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE; + } + + if mbedtls_mpi_cmp_int(M, 0) <= 0 || (*M).private_p.read() & 1 == 0 { + return MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + + if mbedtls_mpi_cmp_int(Y, 0) < 0 { + return MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + + if mbedtls_mpi_cmp_int(Y, 0) == 0 { + return mbedtls_mpi_lset(Z, 1); + } + + // Determine RR pointer, either _RR for cached value or local RR_new + let mut rinv: *mut mbedtls_mpi; + + let mut rinv_new = mbedtls_mpi { + private_s: 0, + private_n: 0, + private_p: core::ptr::null_mut(), + }; + + if prec_RR.is_null() { + mbedtls_mpi_init(&mut rinv_new); + rinv = &mut rinv_new; + } else { + rinv = prec_RR; + } + + if (*rinv).private_p.is_null() { + calculate_rinv(rinv, M, num_words); + } + + ret = mbedtls_mpi_grow(Z, m_words); + let mut rsa = Rsa::new(RSA::steal()); + nb::block!(rsa.ready()).unwrap(); + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + match num_words * 4 { + U512::BYTES => { + const OP_SIZE: usize = U512::BYTES; + let mut base = [0u8; OP_SIZE]; + let mut exponent = [0u8; OP_SIZE]; + let mut modulus = [0u8; OP_SIZE]; + let mut r = [0u8; OP_SIZE]; + copy_bytes((*X).private_p as *const u8, base.as_mut_ptr(), x_words * 4); + copy_bytes( + (*Y).private_p as *const u8, + exponent.as_mut_ptr(), + y_words * 4, + ); + copy_bytes( + (*M).private_p as *const u8, + modulus.as_mut_ptr(), + m_words * 4, + ); + copy_bytes( + (*rinv).private_p as *const u8, + r.as_mut_ptr(), + mpi_words(rinv) * 4, + ); + let mut mod_exp = RsaModularExponentiation::::new( + &mut rsa, + &exponent, // exponent (Y) Y_MEM + &modulus, // modulus (M) M_MEM + compute_mprime(M), // mprime + ); + let mut out = [0u8; OP_SIZE]; + mod_exp.start_exponentiation( + &base, // X_MEM + &r, // Z_MEM + ); + + mod_exp.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, m_words); + } + U1024::BYTES => { + const OP_SIZE: usize = U1024::BYTES; + let mut base = [0u8; OP_SIZE]; + let mut exponent = [0u8; OP_SIZE]; + let mut modulus = [0u8; OP_SIZE]; + let mut r = [0u8; OP_SIZE]; + copy_bytes((*X).private_p as *const u8, base.as_mut_ptr(), x_words * 4); + copy_bytes( + (*Y).private_p as *const u8, + exponent.as_mut_ptr(), + y_words * 4, + ); + copy_bytes( + (*M).private_p as *const u8, + modulus.as_mut_ptr(), + m_words * 4, + ); + copy_bytes( + (*rinv).private_p as *const u8, + r.as_mut_ptr(), + mpi_words(rinv) * 4, + ); + let mut mod_exp = RsaModularExponentiation::::new( + &mut rsa, + &exponent, // exponent (Y) Y_MEM + &modulus, // modulus (M) M_MEM + compute_mprime(M), // mprime + ); + let mut out = [0u8; OP_SIZE]; + mod_exp.start_exponentiation( + &base, // X_MEM + &r, // Z_MEM + ); + + mod_exp.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, m_words); + } + U2048::BYTES => { + const OP_SIZE: usize = U2048::BYTES; + let mut base = [0u8; OP_SIZE]; + let mut exponent = [0u8; OP_SIZE]; + let mut modulus = [0u8; OP_SIZE]; + let mut r = [0u8; OP_SIZE]; + copy_bytes((*X).private_p as *const u8, base.as_mut_ptr(), x_words * 4); + copy_bytes( + (*Y).private_p as *const u8, + exponent.as_mut_ptr(), + y_words * 4, + ); + copy_bytes( + (*M).private_p as *const u8, + modulus.as_mut_ptr(), + m_words * 4, + ); + copy_bytes( + (*rinv).private_p as *const u8, + r.as_mut_ptr(), + mpi_words(rinv) * 4, + ); + let mut mod_exp = RsaModularExponentiation::::new( + &mut rsa, + &exponent, // exponent (Y) Y_MEM + &modulus, // modulus (M) M_MEM + compute_mprime(M), // mprime + ); + let mut out = [0u8; OP_SIZE]; + mod_exp.start_exponentiation( + &base, // X_MEM + &r, // Z_MEM + ); + + mod_exp.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, m_words); + } + U4096::BYTES => { + const OP_SIZE: usize = U4096::BYTES; + let mut base = [0u8; OP_SIZE]; + let mut exponent = [0u8; OP_SIZE]; + let mut modulus = [0u8; OP_SIZE]; + let mut r = [0u8; OP_SIZE]; + copy_bytes((*X).private_p as *const u8, base.as_mut_ptr(), x_words * 4); + copy_bytes( + (*Y).private_p as *const u8, + exponent.as_mut_ptr(), + y_words * 4, + ); + copy_bytes( + (*M).private_p as *const u8, + modulus.as_mut_ptr(), + m_words * 4, + ); + copy_bytes( + (*rinv).private_p as *const u8, + r.as_mut_ptr(), + mpi_words(rinv) * 4, + ); + let mut mod_exp = RsaModularExponentiation::::new( + &mut rsa, + &exponent, // exponent (Y) Y_MEM + &modulus, // modulus (M) M_MEM + compute_mprime(M), // mprime + ); + let mut out = [0u8; OP_SIZE]; + mod_exp.start_exponentiation( + &base, // X_MEM + &r, // Z_MEM + ); + + mod_exp.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, m_words); + } + op => { + todo!("Implement operand: {}", op); + } + } + + // Compensate for negative X + if (*X).private_s == -1 && ((*Y).private_p.read() & 1) != 0 { + (*Z).private_s = -1; + error_checked!(mbedtls_mpi_add_mpi(Z, M, Z)); + } else { + (*Z).private_s = 1; + } + + if prec_RR.is_null() { + mbedtls_mpi_free(&mut rinv_new); + } + ret +} + +#[inline] +const fn bits_to_words(bits: usize) -> usize { + (bits + 31) / 32 +} + +const fn calculate_hw_words(words: usize) -> usize { + // Round up number of words to nearest + // 512 bit (16 word) block count. + #[cfg(feature = "esp32")] + return (words + 0xF) & !0xF; + #[cfg(not(feature = "esp32"))] + words +} + +unsafe fn mpi_words(X: *const mbedtls_mpi) -> usize { + for i in (0..=(*X).private_n).rev() { + if (*X).private_p.add(i - 1).read() != 0 { + return i; + } + } + 0 +} + +/// Deal with the case when X & Y are too long for the hardware unit, by splitting one operand +/// into two halves. +/// +/// Y must be the longer operand +/// +/// Slice Y into Yp, Ypp such that: +/// Yp = lower 'b' bits of Y +/// Ypp = upper 'b' bits of Y (right shifted) +/// +/// Such that +/// Z = X * Y +/// Z = X * (Yp + Ypp< c_int { + let mut ret = 0; + + // Rather than slicing in two on bits we slice on limbs (32 bit words) + let words_slice: usize = y_words / 2; + + // Holds the lower bits of Y (declared to reuse Y's array contents to save on copying) + let yp: mbedtls_mpi = mbedtls_mpi { + private_p: (*Y).private_p, + private_n: words_slice, + private_s: (*Y).private_s, + }; + + // Holds the upper bits of Y, right shifted (also reuse Y's array contents) + let ypp: mbedtls_mpi = mbedtls_mpi { + private_p: (*Y).private_p.add(words_slice), + private_n: y_words - words_slice, + private_s: (*Y).private_s, + }; + + let mut x_temp = mbedtls_mpi { + private_s: 0, + private_n: 0, + private_p: core::ptr::null_mut(), + }; + + mbedtls_mpi_init(&mut x_temp); + + /* Grow Z to result size early, avoid interim allocations */ + mbedtls_mpi_grow(Z, z_words); + + error_checked!(mbedtls_mpi_mul_mpi(&mut x_temp, X, &yp)); + + // Z = b_upper * B + error_checked!(mbedtls_mpi_mul_mpi(Z, X, &ypp)); + + // X = X << b + error_checked!(mbedtls_mpi_shift_l(Z, words_slice * 32)); + + // X += Xtemp + error_checked!(mbedtls_mpi_add_mpi(Z, Z, &x_temp)); + + mbedtls_mpi_free(&mut x_temp); + + ret +} + +unsafe fn mbedtls_mpi_mult_mpi_failover_mod_mult( + X: *mut mbedtls_mpi, + A: *const mbedtls_mpi, + B: *const mbedtls_mpi, + z_words: usize, +) -> c_int { + let mut ret = 0; + + let a_bits = mbedtls_mpi_bitlen(A); + let b_bits = mbedtls_mpi_bitlen(B); + // TODO: We can have the words value from the mpi + let a_words = bits_to_words(a_bits); + let b_words = bits_to_words(b_bits); + let hw_words = calculate_hw_words(z_words); + + log::info!("hw_words {}", hw_words); + + const OP_SIZE: usize = U4096::BYTES; + let mut operand_a = [0u8; OP_SIZE]; + let mut operand_b = [0u8; OP_SIZE]; + copy_bytes( + (*A).private_p as *const u8, + operand_a.as_mut_ptr(), + a_words * 4, + ); + copy_bytes( + (*B).private_p as *const u8, + operand_b.as_mut_ptr(), + b_words * 4, + ); + + let mut rsa = Rsa::new(RSA::steal()); + nb::block!(rsa.ready()).unwrap(); + let mut calc = RsaModularMultiplication::::new( + &mut rsa, + &operand_a, + &operand_b, + &[u8::MAX; U4096::BYTES], + 1, + ); + calc.start_modular_multiplication(&[0u8; U4096::BYTES]); + + let mut out = [0u8; OP_SIZE]; + calc.read_results(&mut out); + + // Grow X to result size early, avoid interim allocations + error_checked!(mbedtls_mpi_grow(X, hw_words)); + + copy_bytes(out.as_ptr() as *mut u32, (*X).private_p, hw_words); + (*X).private_s = (*A).private_s * (*B).private_s; + + // Relevant: https://github.com/espressif/esp-idf/issues/11850 + // + // If z_words < mpi_words(Z) (the actual words taken by the MPI result), + // the assert fails due to unsigned arithmetic - most likely hardware + // peripheral has produced an incorrect result for MPI operation. + // This can happen if data fed to the peripheral register was incorrect. + // + // z_words is calculated as the worst-case possible size of the result + // MPI Z. The difference between z_words and the actual words taken by + // the MPI result (mpi_words(Z)) can be a maximum of 1 word. + // The value z_bits (actual bits taken by the MPI result) is calculated + // as x_bits + y_bits bits, however, in some cases, z_bits can be + // x_bits + y_bits - 1 bits (see example below). + // 0b1111 * 0b1111 = 0b11100001 -> 8 bits + // 0b1000 * 0b1000 = 0b01000000 -> 7 bits. + // The code rounds up to the nearest word size, so the maximum difference + // could be of only 1 word. The assert handles this. + assert!(z_words - unsafe { mpi_words(X) } <= 1); + + ret +} + +#[inline] +unsafe fn copy_bytes(src: *const T, dst: *mut T, count: usize) +where + T: Copy, +{ + for i in 0..count { + *dst.add(i) = *src.add(i); + } +} + +// Baseline multiplication: Z = X * Y (HAC 14.12) +#[no_mangle] +pub unsafe extern "C" fn mbedtls_mpi_mul_mpi( + Z: *mut mbedtls_mpi, + X: *const mbedtls_mpi, + Y: *const mbedtls_mpi, +) -> c_int { + let mut ret = 0; + + let x_bits = mbedtls_mpi_bitlen(X); + let y_bits = mbedtls_mpi_bitlen(Y); + // TODO: We can have the words value from the mpi + let x_words = bits_to_words(x_bits); + let y_words = bits_to_words(y_bits); + let z_words = bits_to_words(x_bits + y_bits); + let hw_words = calculate_hw_words(core::cmp::max(x_words, y_words)); + + // Short-circuit eval if either argument is 0 or 1. + // + // This is needed as the mpi modular division + // argument will sometimes call in here when one + // argument is too large for the hardware unit, but other + // argument is zero or one. + if x_bits == 0 || y_bits == 0 { + mbedtls_mpi_lset(Z, 0); + return 0; + } + if x_bits == 1 { + ret = mbedtls_mpi_copy(Z, Y); + (*Z).private_s *= (*X).private_s; + return ret; + } + if y_bits == 1 { + ret = mbedtls_mpi_copy(Z, X); + (*Z).private_s *= (*Y).private_s; + return ret; + } + + // Grow Z to result size early, avoid interim allocations + error_checked!(mbedtls_mpi_grow(Z, z_words)); + + // If either factor is over 2048 bits, we can't use the standard hardware multiplier + // (it assumes result is double longest factor, and result is max 4096 bits.) + // + // However, we can fail over to mod_mult for up to 4096 bits of result (modulo + // multiplication doesn't have the same restriction, so result is simply the + // number of bits in X plus number of bits in in Y.) + + if hw_words * 32 > SOC_RSA_MAX_BIT_LEN / 2 { + if z_words * 32 <= SOC_RSA_MAX_BIT_LEN { + // Note: It's possible to use mpi_mult_mpi_overlong + // for this case as well, but it's very slightly + // slower and requires a memory allocation. + // return mbedtls_mpi_mult_mpi_failover_mod_mult(X, A, B, z_words); + } + // else { + // Still too long for the hardware unit... + if y_words > x_words { + return mpi_mult_mpi_overlong(Z, X, Y, y_words, z_words); + } else { + return mpi_mult_mpi_overlong(Z, Y, X, x_words, z_words); + } + // } + } + + // Otherwise, we can use the (faster) multiply hardware unit + let mut rsa = Rsa::new(RSA::steal()); + nb::block!(rsa.ready()).unwrap(); + match hw_words * 4 { + U64::BYTES => { + // FIXME: This will soft-lock if using 64 for operand size + const OP_SIZE: usize = U128::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + let mut out = [0u8; OP_SIZE * 2]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U128::BYTES => { + const OP_SIZE: usize = U128::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + let mut out = [0u8; OP_SIZE * 2]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U256::BYTES => { + const OP_SIZE: usize = U256::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U384::BYTES => { + const OP_SIZE: usize = U384::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U512::BYTES => { + const OP_SIZE: usize = U512::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + // TODO: Is it normal to have hw_words * 4 not being a multiple of 32? + 68 => { + const OP_SIZE: usize = U576::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U1024::BYTES => { + const OP_SIZE: usize = U1024::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + // TODO: Is it normal to have hw_words * 4 not being a multiple of 32? + 132 | U1088::BYTES => { + const OP_SIZE: usize = U1088::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + U2048::BYTES => { + const OP_SIZE: usize = U2048::BYTES; + let mut operand_x = [0u8; OP_SIZE]; + let mut operand_y = [0u8; OP_SIZE]; + copy_bytes( + (*X).private_p as *const u8, + operand_x.as_mut_ptr(), + x_words * 4, + ); + copy_bytes( + (*Y).private_p as *const u8, + operand_y.as_mut_ptr(), + y_words * 4, + ); + let mut calc = RsaMultiplication::::new(&mut rsa, &operand_x); + calc.start_multiplication(&operand_y); + let mut out = [0u8; OP_SIZE * 2]; + calc.read_results(&mut out); + copy_bytes(out.as_ptr() as *mut u32, (*Z).private_p, z_words); + } + op => { + log::warn!("U64::BYTES {}", U64::BYTES); + log::warn!("U128::BYTES {}", U128::BYTES); + log::warn!("U192::BYTES {}", U192::BYTES); + log::warn!("U256::BYTES {}", U256::BYTES); + log::warn!("U320::BYTES {}", U320::BYTES); + log::warn!("U384::BYTES {}", U384::BYTES); + log::warn!("U448::BYTES {}", U448::BYTES); + log::warn!("U512::BYTES {}", U512::BYTES); + log::warn!("U576::BYTES {}", U576::BYTES); + log::warn!("U640::BYTES {}", U640::BYTES); + log::warn!("U704::BYTES {}", U704::BYTES); + log::warn!("U768::BYTES {}", U768::BYTES); + log::warn!("U832::BYTES {}", U832::BYTES); + log::warn!("U896::BYTES {}", U896::BYTES); + log::warn!("U960::BYTES {}", U960::BYTES); + log::warn!("U1024::BYTES {}", U1024::BYTES); + log::warn!("U1088::BYTES {}", U1088::BYTES); + log::warn!("U1152::BYTES {}", U1152::BYTES); + log::warn!("U1216::BYTES {}", U1216::BYTES); + log::warn!("U1280::BYTES {}", U1280::BYTES); + log::warn!("U1344::BYTES {}", U1344::BYTES); + log::warn!("U1408::BYTES {}", U1408::BYTES); + log::warn!("U1472::BYTES {}", U1472::BYTES); + log::warn!("U1536::BYTES {}", U1536::BYTES); + log::warn!("U1600::BYTES {}", U1600::BYTES); + log::warn!("U1664::BYTES {}", U1664::BYTES); + log::warn!("U1728::BYTES {}", U1728::BYTES); + log::warn!("U1792::BYTES {}", U1792::BYTES); + log::warn!("U1856::BYTES {}", U1856::BYTES); + log::warn!("U1920::BYTES {}", U1920::BYTES); + log::warn!("U1984::BYTES {}", U1984::BYTES); + log::warn!("U2048::BYTES {}", U2048::BYTES); + log::warn!("U2112::BYTES {}", U2112::BYTES); + log::warn!("U2176::BYTES {}", U2176::BYTES); + log::warn!("U2240::BYTES {}", U2240::BYTES); + log::warn!("U2304::BYTES {}", U2304::BYTES); + log::warn!("U2368::BYTES {}", U2368::BYTES); + log::warn!("U2432::BYTES {}", U2432::BYTES); + log::warn!("U2496::BYTES {}", U2496::BYTES); + log::warn!("U2560::BYTES {}", U2560::BYTES); + log::warn!("U2624::BYTES {}", U2624::BYTES); + log::warn!("U2688::BYTES {}", U2688::BYTES); + log::warn!("U2752::BYTES {}", U2752::BYTES); + log::warn!("U2816::BYTES {}", U2816::BYTES); + log::warn!("U2880::BYTES {}", U2880::BYTES); + log::warn!("U2944::BYTES {}", U2944::BYTES); + log::warn!("U3008::BYTES {}", U3008::BYTES); + log::warn!("U3072::BYTES {}", U3072::BYTES); + log::warn!("U3136::BYTES {}", U3136::BYTES); + log::warn!("U3200::BYTES {}", U3200::BYTES); + log::warn!("U3264::BYTES {}", U3264::BYTES); + log::warn!("U3328::BYTES {}", U3328::BYTES); + log::warn!("U3392::BYTES {}", U3392::BYTES); + log::warn!("U3456::BYTES {}", U3456::BYTES); + log::warn!("U3520::BYTES {}", U3520::BYTES); + log::warn!("U3584::BYTES {}", U3584::BYTES); + log::warn!("U3648::BYTES {}", U3648::BYTES); + log::warn!("U3712::BYTES {}", U3712::BYTES); + log::warn!("U3776::BYTES {}", U3776::BYTES); + log::warn!("U3840::BYTES {}", U3840::BYTES); + log::warn!("U3904::BYTES {}", U3904::BYTES); + log::warn!("U3968::BYTES {}", U3968::BYTES); + log::warn!("U4032::BYTES {}", U4032::BYTES); + log::warn!("U4096::BYTES {}", U4096::BYTES); + todo!("Implement operand: {}", op); + } + } + (*Z).private_s = (*X).private_s * (*Y).private_s; + + ret +} + +#[no_mangle] +pub unsafe extern "C" fn mbedtls_mpi_mul_int( + X: *mut mbedtls_mpi, + A: *const mbedtls_mpi, + b: mbedtls_mpi_uint, +) -> c_int { + let B: mbedtls_mpi = mbedtls_mpi { + private_s: 1, + private_n: 1, + private_p: [b].as_mut_ptr(), + }; + + mbedtls_mpi_mul_mpi(X, A, &B) +} diff --git a/esp-mbedtls/src/compat.rs b/esp-mbedtls/src/compat.rs index 6f4ef58..3b85263 100644 --- a/esp-mbedtls/src/compat.rs +++ b/esp-mbedtls/src/compat.rs @@ -1,9 +1,11 @@ use core::ffi::VaListImpl; use core::fmt::Write; +use crate::random; + #[no_mangle] -extern "C" fn putchar() { - todo!() +extern "C" fn putchar(c: crate::c_int) { + log::info!("{c}"); } #[no_mangle] @@ -120,8 +122,8 @@ extern "C" fn vsnprintf( } #[no_mangle] -extern "C" fn rand() { - todo!() +extern "C" fn rand() -> crate::c_ulong { + unsafe { random() } } pub struct StrBuf { diff --git a/esp-mbedtls/src/lib.rs b/esp-mbedtls/src/lib.rs index fe3c674..a98effc 100644 --- a/esp-mbedtls/src/lib.rs +++ b/esp-mbedtls/src/lib.rs @@ -20,15 +20,43 @@ pub use esp32s3_hal as hal; mod compat; +#[cfg(any(feature = "esp32s3", feature = "esp32c3"))] +mod bignum; + use core::ffi::CStr; use core::mem::size_of; +#[no_mangle] +pub unsafe extern "C" fn log_timestamp() { + log::info!("timestamp: {}", crate::hal::systimer::SystemTimer::now()); +} + use compat::StrBuf; use embedded_io::Read; use embedded_io::Write; use esp_mbedtls_sys::bindings::*; +/// +/// Re-export self-tests +pub use esp_mbedtls_sys::bindings::{ + // AES + mbedtls_aes_self_test, + // MD5 + mbedtls_md5_self_test, + // Bignum + mbedtls_mpi_self_test, + // RSA + mbedtls_rsa_self_test, + // SHA + mbedtls_sha1_self_test, + mbedtls_sha256_self_test, + mbedtls_sha384_self_test, + mbedtls_sha512_self_test, +}; use esp_mbedtls_sys::c_types::*; +#[cfg(not(feature = "esp32"))] +pub use esp_mbedtls_sys::bindings::mbedtls_sha224_self_test; + // these will come from esp-wifi (i.e. this can only be used together with esp-wifi) extern "C" { fn free(ptr: *const u8); diff --git a/examples/crypto_self_test.rs b/examples/crypto_self_test.rs new file mode 100644 index 0000000..849227d --- /dev/null +++ b/examples/crypto_self_test.rs @@ -0,0 +1,146 @@ +//! Run crypto self tests to ensure their functionnality +#![no_std] +#![no_main] + +#[doc(hidden)] +#[cfg(feature = "esp32")] +pub use esp32_hal as hal; +#[doc(hidden)] +#[cfg(feature = "esp32c3")] +pub use esp32c3_hal as hal; +#[doc(hidden)] +#[cfg(feature = "esp32s2")] +pub use esp32s2_hal as hal; +#[doc(hidden)] +#[cfg(feature = "esp32s3")] +pub use esp32s3_hal as hal; + +use esp_backtrace as _; +use esp_mbedtls::set_debug; +use esp_println::{logger::init_logger, println}; + +/// Only used for ROM functions +#[allow(unused_imports)] +use esp_wifi::{initialize, EspWifiInitFor}; +use hal::{clock::ClockControl, peripherals::Peripherals, prelude::*, systimer::SystemTimer, Rng}; + +#[entry] +fn main() -> ! { + init_logger(log::LevelFilter::Info); + + // Init ESP-WIFI heap for malloc + let peripherals = Peripherals::take(); + #[cfg(feature = "esp32")] + let mut system = peripherals.DPORT.split(); + #[cfg(not(feature = "esp32"))] + #[allow(unused_mut)] + let mut system = peripherals.SYSTEM.split(); + let clocks = ClockControl::max(system.clock_control).freeze(); + + #[cfg(feature = "esp32c3")] + let timer = hal::systimer::SystemTimer::new(peripherals.SYSTIMER).alarm0; + #[cfg(any(feature = "esp32", feature = "esp32s2", feature = "esp32s3"))] + let timer = hal::timer::TimerGroup::new(peripherals.TIMG1, &clocks).timer0; + let _ = initialize( + EspWifiInitFor::Wifi, + timer, + Rng::new(peripherals.RNG), + system.radio_clock_control, + &clocks, + ) + .unwrap(); + + set_debug(1); + + // println!("Testing AES"); + // unsafe { + // esp_mbedtls::mbedtls_aes_self_test(1i32); + // } + // println!("Testing MD5"); + // unsafe { + // esp_mbedtls::mbedtls_md5_self_test(1i32); + // } + println!("Testing RSA"); + unsafe { + esp_mbedtls::mbedtls_rsa_self_test(1i32); + } + // println!("Testing SHA"); + unsafe { + // esp_mbedtls::mbedtls_sha1_self_test(1i32); + // #[cfg(not(feature = "esp32"))] + // esp_mbedtls::mbedtls_sha224_self_test(1i32); + // esp_mbedtls::mbedtls_sha256_self_test(1i32); + // esp_mbedtls::mbedtls_sha384_self_test(1i32); + // esp_mbedtls::mbedtls_sha512_self_test(1i32); + + // HW Crypto: + // Testing RSA + // INFO - RSA key validation: + // INFO - passed + // PKCS#1 encryption : + // INFO - passed + // PKCS#1 decryption : + // INFO - passed + // INFO - PKCS#1 data sign : + // INFO - passed + // PKCS#1 sig. verify: + // INFO - passed + // INFO - 10 + // INFO - pre_cal 16377170 + // INFO - MPI test #1 (mul_mpi): + // INFO - passed + // INFO - MPI test #2 (div_mpi): + // INFO - passed + // INFO - MPI test #3 (exp_mod): + // INFO - passed + // INFO - MPI test #4 (inv_mod): + // INFO - passed + // INFO - MPI test #5 (simple gcd): + // INFO - passed + // INFO - 10 + // INFO - post_cal 17338357 + // Took 961187 cycles + // Done + + // SW Crypto: + // Testing RSA + // INFO - RSA key validation: + // INFO - passed + // PKCS#1 encryption : + // INFO - passed + // PKCS#1 decryption : + // INFO - passed + // INFO - PKCS#1 data sign : + // INFO - passed + // PKCS#1 sig. verify: + // INFO - passed + // INFO - 10 + // INFO - pre_cal 19067376 + // INFO - MPI test #1 (mul_mpi): + // INFO - passed + // INFO - MPI test #2 (div_mpi): + // INFO - passed + // INFO - MPI test #3 (exp_mod): + // INFO - passed + // INFO - MPI test #4 (inv_mod): + // INFO - passed + // INFO - MPI test #5 (simple gcd): + // INFO - passed + // INFO - 10 + // INFO - post_cal 20393146 + // Took 1325770 cycles + // Done + + let pre_calc = SystemTimer::now(); + log::info!("pre_cal {}", pre_calc); + esp_mbedtls::mbedtls_mpi_self_test(1i32); + let post_calc = SystemTimer::now(); + let hw_time = post_calc - pre_calc; + log::info!("post_cal {}", post_calc); + println!("Took {} cycles", hw_time); + } + + println!("Done"); + + loop {} +} diff --git a/mbedtls b/mbedtls index 1873d3b..cadbbd9 160000 --- a/mbedtls +++ b/mbedtls @@ -1 +1 @@ -Subproject commit 1873d3bfc2da771672bd8e7e8f41f57e0af77f33 +Subproject commit cadbbd91bb15c64e7bd4e8490010ddb78eed2121