Skip to content

Commit

Permalink
refactor c strings type mapping (#78)
Browse files Browse the repository at this point in the history
* refactor c strings type mapping

* fmt

* bump: update sherpa-rs and sherpa-rs-sys versions to 0.6.5
  • Loading branch information
thewh1teagle authored Jan 19, 2025
1 parent 548e2f3 commit 93e1cfa
Show file tree
Hide file tree
Showing 14 changed files with 28 additions and 54 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/sherpa-rs-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sherpa-rs-sys"
version = "0.6.4"
version = "0.6.5"
edition = "2021"
authors = ["thewh1teagle"]
homepage = "https://github.com/thewh1teagle/sherpa-rs"
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sherpa-rs"
version = "0.6.4"
version = "0.6.5"
edition = "2021"
authors = ["thewh1teagle"]
license = "MIT"
Expand All @@ -21,7 +21,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
eyre = "0.6.12"
hound = { version = "3.5.1" }
sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.4", default-features = false }
sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.5", default-features = false }
tracing = "0.1.40"

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/audio_tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl AudioTag {

for i in 0..self.config.top_k {
let event = *results.add(i.try_into().unwrap());
let event_name = cstr_to_string((*event).name);
let event_name = cstr_to_string((*event).name as _);
events.push(event_name);
}

Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/embedding_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl EmbeddingManager {
if name.is_null() {
return None;
}
let name = cstr_to_string(name);
let name = cstr_to_string(name as _);
Some(name)
}
}
Expand All @@ -57,7 +57,7 @@ impl EmbeddingManager {
let mut matches: Vec<SpeakerMatch> = Vec::new();
for i in 0..result.count {
let match_c = matches_c[i as usize];
let name = cstr_to_string(match_c.name);
let name = cstr_to_string(match_c.name as _);
let score = match_c.score;
matches.push(SpeakerMatch { name, score });
}
Expand All @@ -76,7 +76,7 @@ impl EmbeddingManager {
);

if status.is_negative() {
bail!("Failed to register {}", name)
bail!("Failed to register {}", name);
}
Ok(())
}
Expand All @@ -90,6 +90,6 @@ impl Drop for EmbeddingManager {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroySpeakerEmbeddingManager(self.manager);
};
}
}
}
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/keyword_spot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl KeywordSpot {
let spotter = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordSpotter(&sherpa_config) };

if spotter.is_null() {
bail!("Failed to create keyword spotter")
bail!("Failed to create keyword spotter");
}
let stream = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordStream(spotter) };
if stream.is_null() {
Expand Down Expand Up @@ -139,7 +139,7 @@ impl KeywordSpot {
let result_ptr = sherpa_rs_sys::SherpaOnnxGetKeywordResult(self.spotter, self.stream);
let mut keyword = None;
if !result_ptr.is_null() {
let decoded_keyword = cstr_to_string((*result_ptr).keyword);
let decoded_keyword = cstr_to_string((*result_ptr).keyword as _);
if !decoded_keyword.is_empty() {
keyword = Some(decoded_keyword);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/language_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ impl SpokenLanguageId {
let language_result_ptr =
sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCompute(self.slid, stream);
if language_result_ptr.is_null() || (*language_result_ptr).lang.is_null() {
bail!("language ptr is null")
bail!("language ptr is null");
}
let language_ptr = (*language_result_ptr).lang;
let language = cstr_to_string(language_ptr);
let language = cstr_to_string(language_ptr as _);
// Free
sherpa_rs_sys::SherpaOnnxDestroySpokenLanguageIdentificationResult(language_result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/moonshine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl MoonshineRecognizer {
let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}

Ok(Self { recognizer })
Expand All @@ -140,7 +140,7 @@ impl MoonshineRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);
let result = MoonshineRecognizerResult { text };
// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/punctuate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Punctuation {
unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config) };

if audio_punctuation.is_null() {
bail!("Failed to create audio punctuation")
bail!("Failed to create audio punctuation");
}
Ok(Self { audio_punctuation })
}
Expand All @@ -51,7 +51,7 @@ impl Punctuation {
self.audio_punctuation,
text.as_ptr(),
);
let text_with_punct = cstr_to_string(text_with_punct_ptr);
let text_with_punct = cstr_to_string(text_with_punct_ptr as _);
sherpa_rs_sys::SherpaOfflinePunctuationFreeText(text_with_punct_ptr);
text_with_punct
}
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl KokoroTts {
},
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/tts/matcha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl MatchaTts {
kokoro: mem::zeroed::<_>(),
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
34 changes: 4 additions & 30 deletions crates/sherpa-rs/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use std::ffi::CString;
use std::ffi::{c_char, CString};

// Smart pointer for CString
pub struct RawCStr {
#[cfg(target_os = "android")]
ptr: *mut u8,

#[cfg(not(target_os = "android"))]
ptr: *mut i8,
ptr: *mut std::ffi::c_char,
}

impl RawCStr {
Expand All @@ -23,13 +19,7 @@ impl RawCStr {
/// This function only returns the raw pointer and does not transfer ownership.
/// The pointer remains valid as long as the `CStr` instance exists.
/// Be cautious not to deallocate or modify the pointer after using `CStr::new`.
#[cfg(target_os = "android")]
pub fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}

#[cfg(not(target_os = "android"))]
pub fn as_ptr(&self) -> *const i8 {
pub fn as_ptr(&self) -> *const c_char {
self.ptr
}
}
Expand All @@ -38,29 +28,13 @@ impl Drop for RawCStr {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe {
#[cfg(target_os = "android")]
let _ = CString::from_raw(self.ptr as *mut u8);

#[cfg(not(target_os = "android"))]
let _ = CString::from_raw(self.ptr);
}
}
}
}

#[cfg(target_os = "android")]
pub fn cstr_to_string(ptr: *const u8) -> String {
unsafe {
if ptr.is_null() {
String::new()
} else {
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
}
}
}

#[cfg(not(target_os = "android"))]
pub fn cstr_to_string(ptr: *const i8) -> String {
pub fn cstr_to_string(ptr: *mut c_char) -> String {
unsafe {
if ptr.is_null() {
String::new()
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl WhisperRecognizer {
let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}

Ok(Self { recognizer })
Expand All @@ -147,7 +147,7 @@ impl WhisperRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);
// let timestamps: &[f32] =
// std::slice::from_raw_parts(raw_result.timestamps, raw_result.count as usize);
let result = WhisperRecognizerResult { text };
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/zipformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl ZipFormer {
unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&recognizer_config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}
Ok(Self { recognizer })
}
Expand All @@ -113,7 +113,7 @@ impl ZipFormer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);

// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
Expand Down

0 comments on commit 93e1cfa

Please sign in to comment.