From 93e1cfa929de7acb4a3ff042d12f08fb5f4abc21 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Sun, 19 Jan 2025 18:51:36 +0200 Subject: [PATCH] refactor c strings type mapping (#78) * refactor c strings type mapping * fmt * bump: update sherpa-rs and sherpa-rs-sys versions to 0.6.5 --- Cargo.lock | 4 +-- crates/sherpa-rs-sys/Cargo.toml | 2 +- crates/sherpa-rs/Cargo.toml | 4 +-- crates/sherpa-rs/src/audio_tag.rs | 2 +- crates/sherpa-rs/src/embedding_manager.rs | 8 +++--- crates/sherpa-rs/src/keyword_spot.rs | 4 +-- crates/sherpa-rs/src/language_id.rs | 4 +-- crates/sherpa-rs/src/moonshine.rs | 4 +-- crates/sherpa-rs/src/punctuate.rs | 4 +-- crates/sherpa-rs/src/tts/kokoro.rs | 2 +- crates/sherpa-rs/src/tts/matcha.rs | 2 +- crates/sherpa-rs/src/utils.rs | 34 +++-------------------- crates/sherpa-rs/src/whisper.rs | 4 +-- crates/sherpa-rs/src/zipformer.rs | 4 +-- 14 files changed, 28 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d53d1ec..edb104e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -731,7 +731,7 @@ dependencies = [ [[package]] name = "sherpa-rs" -version = "0.6.4" +version = "0.6.5" dependencies = [ "clap", "eyre", @@ -742,7 +742,7 @@ dependencies = [ [[package]] name = "sherpa-rs-sys" -version = "0.6.4" +version = "0.6.5" dependencies = [ "bindgen", "bzip2", diff --git a/crates/sherpa-rs-sys/Cargo.toml b/crates/sherpa-rs-sys/Cargo.toml index 713993d..7f4aba9 100644 --- a/crates/sherpa-rs-sys/Cargo.toml +++ b/crates/sherpa-rs-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sherpa-rs-sys" -version = "0.6.4" +version = "0.6.5" edition = "2021" authors = ["thewh1teagle"] homepage = "https://github.com/thewh1teagle/sherpa-rs" diff --git a/crates/sherpa-rs/Cargo.toml b/crates/sherpa-rs/Cargo.toml index a7c2418..baf69bf 100644 --- a/crates/sherpa-rs/Cargo.toml +++ b/crates/sherpa-rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sherpa-rs" -version = "0.6.4" +version = "0.6.5" edition = "2021" authors = ["thewh1teagle"] license = "MIT" @@ -21,7 +21,7 @@ crate-type = ["cdylib", "rlib"] [dependencies] eyre = "0.6.12" hound = { version = "3.5.1" } -sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.4", default-features = false } +sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.5", default-features = false } tracing = "0.1.40" [dev-dependencies] diff --git a/crates/sherpa-rs/src/audio_tag.rs b/crates/sherpa-rs/src/audio_tag.rs index a6dd120..8832ef7 100644 --- a/crates/sherpa-rs/src/audio_tag.rs +++ b/crates/sherpa-rs/src/audio_tag.rs @@ -73,7 +73,7 @@ impl AudioTag { for i in 0..self.config.top_k { let event = *results.add(i.try_into().unwrap()); - let event_name = cstr_to_string((*event).name); + let event_name = cstr_to_string((*event).name as _); events.push(event_name); } diff --git a/crates/sherpa-rs/src/embedding_manager.rs b/crates/sherpa-rs/src/embedding_manager.rs index 3fb17b6..e5a9df4 100644 --- a/crates/sherpa-rs/src/embedding_manager.rs +++ b/crates/sherpa-rs/src/embedding_manager.rs @@ -30,7 +30,7 @@ impl EmbeddingManager { if name.is_null() { return None; } - let name = cstr_to_string(name); + let name = cstr_to_string(name as _); Some(name) } } @@ -57,7 +57,7 @@ impl EmbeddingManager { let mut matches: Vec = Vec::new(); for i in 0..result.count { let match_c = matches_c[i as usize]; - let name = cstr_to_string(match_c.name); + let name = cstr_to_string(match_c.name as _); let score = match_c.score; matches.push(SpeakerMatch { name, score }); } @@ -76,7 +76,7 @@ impl EmbeddingManager { ); if status.is_negative() { - bail!("Failed to register {}", name) + bail!("Failed to register {}", name); } Ok(()) } @@ -90,6 +90,6 @@ impl Drop for EmbeddingManager { fn drop(&mut self) { unsafe { sherpa_rs_sys::SherpaOnnxDestroySpeakerEmbeddingManager(self.manager); - }; + } } } diff --git a/crates/sherpa-rs/src/keyword_spot.rs b/crates/sherpa-rs/src/keyword_spot.rs index 5d78aaa..1dbab05 100644 --- a/crates/sherpa-rs/src/keyword_spot.rs +++ b/crates/sherpa-rs/src/keyword_spot.rs @@ -109,7 +109,7 @@ impl KeywordSpot { let spotter = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordSpotter(&sherpa_config) }; if spotter.is_null() { - bail!("Failed to create keyword spotter") + bail!("Failed to create keyword spotter"); } let stream = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordStream(spotter) }; if stream.is_null() { @@ -139,7 +139,7 @@ impl KeywordSpot { let result_ptr = sherpa_rs_sys::SherpaOnnxGetKeywordResult(self.spotter, self.stream); let mut keyword = None; if !result_ptr.is_null() { - let decoded_keyword = cstr_to_string((*result_ptr).keyword); + let decoded_keyword = cstr_to_string((*result_ptr).keyword as _); if !decoded_keyword.is_empty() { keyword = Some(decoded_keyword); } diff --git a/crates/sherpa-rs/src/language_id.rs b/crates/sherpa-rs/src/language_id.rs index 7f50568..f60e3ca 100644 --- a/crates/sherpa-rs/src/language_id.rs +++ b/crates/sherpa-rs/src/language_id.rs @@ -56,10 +56,10 @@ impl SpokenLanguageId { let language_result_ptr = sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCompute(self.slid, stream); if language_result_ptr.is_null() || (*language_result_ptr).lang.is_null() { - bail!("language ptr is null") + bail!("language ptr is null"); } let language_ptr = (*language_result_ptr).lang; - let language = cstr_to_string(language_ptr); + let language = cstr_to_string(language_ptr as _); // Free sherpa_rs_sys::SherpaOnnxDestroySpokenLanguageIdentificationResult(language_result_ptr); sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream); diff --git a/crates/sherpa-rs/src/moonshine.rs b/crates/sherpa-rs/src/moonshine.rs index df6b66d..4e89f1d 100644 --- a/crates/sherpa-rs/src/moonshine.rs +++ b/crates/sherpa-rs/src/moonshine.rs @@ -122,7 +122,7 @@ impl MoonshineRecognizer { let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) }; if recognizer.is_null() { - bail!("Failed to create recognizer") + bail!("Failed to create recognizer"); } Ok(Self { recognizer }) @@ -140,7 +140,7 @@ impl MoonshineRecognizer { sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream); let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream); let raw_result = result_ptr.read(); - let text = cstr_to_string(raw_result.text); + let text = cstr_to_string(raw_result.text as _); let result = MoonshineRecognizerResult { text }; // Free sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr); diff --git a/crates/sherpa-rs/src/punctuate.rs b/crates/sherpa-rs/src/punctuate.rs index 5bdf895..b425a46 100644 --- a/crates/sherpa-rs/src/punctuate.rs +++ b/crates/sherpa-rs/src/punctuate.rs @@ -39,7 +39,7 @@ impl Punctuation { unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config) }; if audio_punctuation.is_null() { - bail!("Failed to create audio punctuation") + bail!("Failed to create audio punctuation"); } Ok(Self { audio_punctuation }) } @@ -51,7 +51,7 @@ impl Punctuation { self.audio_punctuation, text.as_ptr(), ); - let text_with_punct = cstr_to_string(text_with_punct_ptr); + let text_with_punct = cstr_to_string(text_with_punct_ptr as _); sherpa_rs_sys::SherpaOfflinePunctuationFreeText(text_with_punct_ptr); text_with_punct } diff --git a/crates/sherpa-rs/src/tts/kokoro.rs b/crates/sherpa-rs/src/tts/kokoro.rs index 248988e..24e44df 100644 --- a/crates/sherpa-rs/src/tts/kokoro.rs +++ b/crates/sherpa-rs/src/tts/kokoro.rs @@ -48,7 +48,7 @@ impl KokoroTts { }, }; let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig { - max_num_sentences: 0, + max_num_sentences: config.common_config.max_num_sentences, model: model_config, rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()), rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()), diff --git a/crates/sherpa-rs/src/tts/matcha.rs b/crates/sherpa-rs/src/tts/matcha.rs index ffe9e6a..34ee3d8 100644 --- a/crates/sherpa-rs/src/tts/matcha.rs +++ b/crates/sherpa-rs/src/tts/matcha.rs @@ -60,7 +60,7 @@ impl MatchaTts { kokoro: mem::zeroed::<_>(), }; let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig { - max_num_sentences: 0, + max_num_sentences: config.common_config.max_num_sentences, model: model_config, rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()), rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()), diff --git a/crates/sherpa-rs/src/utils.rs b/crates/sherpa-rs/src/utils.rs index 7b6a8ac..40e5c9b 100644 --- a/crates/sherpa-rs/src/utils.rs +++ b/crates/sherpa-rs/src/utils.rs @@ -1,12 +1,8 @@ -use std::ffi::CString; +use std::ffi::{c_char, CString}; // Smart pointer for CString pub struct RawCStr { - #[cfg(target_os = "android")] - ptr: *mut u8, - - #[cfg(not(target_os = "android"))] - ptr: *mut i8, + ptr: *mut std::ffi::c_char, } impl RawCStr { @@ -23,13 +19,7 @@ impl RawCStr { /// This function only returns the raw pointer and does not transfer ownership. /// The pointer remains valid as long as the `CStr` instance exists. /// Be cautious not to deallocate or modify the pointer after using `CStr::new`. - #[cfg(target_os = "android")] - pub fn as_ptr(&self) -> *const u8 { - self.ptr as *const u8 - } - - #[cfg(not(target_os = "android"))] - pub fn as_ptr(&self) -> *const i8 { + pub fn as_ptr(&self) -> *const c_char { self.ptr } } @@ -38,29 +28,13 @@ impl Drop for RawCStr { fn drop(&mut self) { if !self.ptr.is_null() { unsafe { - #[cfg(target_os = "android")] - let _ = CString::from_raw(self.ptr as *mut u8); - - #[cfg(not(target_os = "android"))] let _ = CString::from_raw(self.ptr); } } } } -#[cfg(target_os = "android")] -pub fn cstr_to_string(ptr: *const u8) -> String { - unsafe { - if ptr.is_null() { - String::new() - } else { - std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned() - } - } -} - -#[cfg(not(target_os = "android"))] -pub fn cstr_to_string(ptr: *const i8) -> String { +pub fn cstr_to_string(ptr: *mut c_char) -> String { unsafe { if ptr.is_null() { String::new() diff --git a/crates/sherpa-rs/src/whisper.rs b/crates/sherpa-rs/src/whisper.rs index 31bae92..d3062a8 100644 --- a/crates/sherpa-rs/src/whisper.rs +++ b/crates/sherpa-rs/src/whisper.rs @@ -129,7 +129,7 @@ impl WhisperRecognizer { let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) }; if recognizer.is_null() { - bail!("Failed to create recognizer") + bail!("Failed to create recognizer"); } Ok(Self { recognizer }) @@ -147,7 +147,7 @@ impl WhisperRecognizer { sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream); let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream); let raw_result = result_ptr.read(); - let text = cstr_to_string(raw_result.text); + let text = cstr_to_string(raw_result.text as _); // let timestamps: &[f32] = // std::slice::from_raw_parts(raw_result.timestamps, raw_result.count as usize); let result = WhisperRecognizerResult { text }; diff --git a/crates/sherpa-rs/src/zipformer.rs b/crates/sherpa-rs/src/zipformer.rs index cac7c38..372d9c0 100644 --- a/crates/sherpa-rs/src/zipformer.rs +++ b/crates/sherpa-rs/src/zipformer.rs @@ -96,7 +96,7 @@ impl ZipFormer { unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&recognizer_config) }; if recognizer.is_null() { - bail!("Failed to create recognizer") + bail!("Failed to create recognizer"); } Ok(Self { recognizer }) } @@ -113,7 +113,7 @@ impl ZipFormer { sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream); let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream); let raw_result = result_ptr.read(); - let text = cstr_to_string(raw_result.text); + let text = cstr_to_string(raw_result.text as _); // Free sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);