Skip to content

Commit

Permalink
refactor c strings type mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jan 19, 2025
1 parent 548e2f3 commit 5c770a2
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 125 deletions.
13 changes: 5 additions & 8 deletions crates/sherpa-rs/src/audio_tag.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use eyre::{bail, Result};
use eyre::{ bail, Result };

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use crate::{ get_default_provider, utils::{ cstr_to_string, RawCStr } };

#[derive(Debug, Default, Clone)]
pub struct AudioTagConfig {
Expand Down Expand Up @@ -62,18 +59,18 @@ impl AudioTag {
stream,
sample_rate as i32,
samples.as_ptr(),
samples.len() as i32,
samples.len() as i32
);

let results = sherpa_rs_sys::SherpaOnnxAudioTaggingCompute(
self.audio_tag,
stream,
self.config.top_k,
self.config.top_k
);

for i in 0..self.config.top_k {
let event = *results.add(i.try_into().unwrap());
let event_name = cstr_to_string((*event).name);
let event_name = cstr_to_string((*event).name as _);
events.push(event_name);
}

Expand Down
20 changes: 10 additions & 10 deletions crates/sherpa-rs/src/embedding_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::utils::{cstr_to_string, RawCStr};
use eyre::{bail, Result};
use crate::utils::{ cstr_to_string, RawCStr };
use eyre::{ bail, Result };

#[derive(Debug, Clone)]
pub struct EmbeddingManager {
Expand All @@ -25,12 +25,12 @@ impl EmbeddingManager {
let name = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerSearch(
self.manager,
embedding.to_owned().as_mut_ptr(),
threshold,
threshold
);
if name.is_null() {
return None;
}
let name = cstr_to_string(name);
let name = cstr_to_string(name as _);
Some(name)
}
}
Expand All @@ -39,14 +39,14 @@ impl EmbeddingManager {
&mut self,
embedding: &[f32],
threshold: f32,
n: i32,
n: i32
) -> Vec<SpeakerMatch> {
unsafe {
let result_ptr = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
self.manager,
embedding.to_owned().as_mut_ptr(),
threshold,
n,
n
);
if result_ptr.is_null() {
return Vec::new();
Expand All @@ -57,7 +57,7 @@ impl EmbeddingManager {
let mut matches: Vec<SpeakerMatch> = Vec::new();
for i in 0..result.count {
let match_c = matches_c[i as usize];
let name = cstr_to_string(match_c.name);
let name = cstr_to_string(match_c.name as _);
let score = match_c.score;
matches.push(SpeakerMatch { name, score });
}
Expand All @@ -72,11 +72,11 @@ impl EmbeddingManager {
let status = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerAdd(
self.manager,
name_c.as_ptr(),
embedding.as_mut_ptr(),
embedding.as_mut_ptr()
);

if status.is_negative() {
bail!("Failed to register {}", name)
bail!("Failed to register {}", name);
}
Ok(())
}
Expand All @@ -90,6 +90,6 @@ impl Drop for EmbeddingManager {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroySpeakerEmbeddingManager(self.manager);
};
}
}
}
15 changes: 6 additions & 9 deletions crates/sherpa-rs/src/keyword_spot.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::ptr::null;

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use eyre::{bail, Result};
use crate::{ get_default_provider, utils::{ cstr_to_string, RawCStr } };
use eyre::{ bail, Result };

#[derive(Debug, Clone)]
pub struct KeywordSpotConfig {
Expand Down Expand Up @@ -109,7 +106,7 @@ impl KeywordSpot {
let spotter = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordSpotter(&sherpa_config) };

if spotter.is_null() {
bail!("Failed to create keyword spotter")
bail!("Failed to create keyword spotter");
}
let stream = unsafe { sherpa_rs_sys::SherpaOnnxCreateKeywordStream(spotter) };
if stream.is_null() {
Expand All @@ -122,15 +119,15 @@ impl KeywordSpot {
pub fn extract_keyword(
&mut self,
samples: Vec<f32>,
sample_rate: u32,
sample_rate: u32
) -> Result<Option<String>> {
// Create keyword spotting stream
unsafe {
sherpa_rs_sys::SherpaOnnxOnlineStreamAcceptWaveform(
self.stream,
sample_rate as i32,
samples.as_ptr(),
samples.len() as i32,
samples.len() as i32
);
sherpa_rs_sys::SherpaOnnxOnlineStreamInputFinished(self.stream);
while sherpa_rs_sys::SherpaOnnxIsKeywordStreamReady(self.spotter, self.stream) == 1 {
Expand All @@ -139,7 +136,7 @@ impl KeywordSpot {
let result_ptr = sherpa_rs_sys::SherpaOnnxGetKeywordResult(self.spotter, self.stream);
let mut keyword = None;
if !result_ptr.is_null() {
let decoded_keyword = cstr_to_string((*result_ptr).keyword);
let decoded_keyword = cstr_to_string((*result_ptr).keyword as _);
if !decoded_keyword.is_empty() {
keyword = Some(decoded_keyword);
}
Expand Down
29 changes: 15 additions & 14 deletions crates/sherpa-rs/src/language_id.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use eyre::{bail, Result};
use crate::{ get_default_provider, utils::{ cstr_to_string, RawCStr } };
use eyre::{ bail, Result };

#[derive(Debug)]
pub struct SpokenLanguageId {
Expand Down Expand Up @@ -37,29 +34,33 @@ impl SpokenLanguageId {
provider: provider.as_ptr(),
whisper,
};
let slid =
unsafe { sherpa_rs_sys::SherpaOnnxCreateSpokenLanguageIdentification(&sherpa_config) };
let slid = unsafe {
sherpa_rs_sys::SherpaOnnxCreateSpokenLanguageIdentification(&sherpa_config)
};

Self { slid }
}

pub fn compute(&mut self, samples: Vec<f32>, sample_rate: u32) -> Result<String> {
unsafe {
let stream =
sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCreateOfflineStream(self.slid);
let stream = sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCreateOfflineStream(
self.slid
);
sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline(
stream,
sample_rate as i32,
samples.as_ptr(),
samples.len().try_into().unwrap(),
samples.len().try_into().unwrap()
);
let language_result_ptr = sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCompute(
self.slid,
stream
);
let language_result_ptr =
sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationCompute(self.slid, stream);
if language_result_ptr.is_null() || (*language_result_ptr).lang.is_null() {
bail!("language ptr is null")
bail!("language ptr is null");
}
let language_ptr = (*language_result_ptr).lang;
let language = cstr_to_string(language_ptr);
let language = cstr_to_string(language_ptr as _);
// Free
sherpa_rs_sys::SherpaOnnxDestroySpokenLanguageIdentificationResult(language_result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
Expand Down
13 changes: 5 additions & 8 deletions crates/sherpa-rs/src/moonshine.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use eyre::{bail, Result};
use crate::{ get_default_provider, utils::{ cstr_to_string, RawCStr } };
use eyre::{ bail, Result };
use std::ptr::null;

#[derive(Debug)]
Expand Down Expand Up @@ -122,7 +119,7 @@ impl MoonshineRecognizer {
let recognizer = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&config) };

if recognizer.is_null() {
bail!("Failed to create recognizer")
bail!("Failed to create recognizer");
}

Ok(Self { recognizer })
Expand All @@ -135,12 +132,12 @@ impl MoonshineRecognizer {
stream,
sample_rate as i32,
samples.as_ptr(),
samples.len().try_into().unwrap(),
samples.len().try_into().unwrap()
);
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text);
let text = cstr_to_string(raw_result.text as _);
let result = MoonshineRecognizerResult { text };
// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
Expand Down
34 changes: 18 additions & 16 deletions crates/sherpa-rs/src/punctuate.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use eyre::{bail, Result};
use eyre::{ bail, Result };

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use crate::{ get_default_provider, utils::{ cstr_to_string, RawCStr } };

#[derive(Debug, Default, Clone)]
pub struct PunctuationConfig {
Expand All @@ -20,12 +17,16 @@ pub struct Punctuation {
impl Punctuation {
pub fn new(config: PunctuationConfig) -> Result<Self> {
let model = RawCStr::new(&config.model);
let provider = RawCStr::new(&config.provider.unwrap_or(if cfg!(target_os = "macos") {
// TODO: sherpa-onnx/issues/1448
"cpu".into()
} else {
get_default_provider()
}));
let provider = RawCStr::new(
&config.provider.unwrap_or(
if cfg!(target_os = "macos") {
// TODO: sherpa-onnx/issues/1448
"cpu".into()
} else {
get_default_provider()
}
)
);

let sherpa_config = sherpa_rs_sys::SherpaOnnxOfflinePunctuationConfig {
model: sherpa_rs_sys::SherpaOnnxOfflinePunctuationModelConfig {
Expand All @@ -35,11 +36,12 @@ impl Punctuation {
provider: provider.as_ptr(),
},
};
let audio_punctuation =
unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config) };
let audio_punctuation = unsafe {
sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config)
};

if audio_punctuation.is_null() {
bail!("Failed to create audio punctuation")
bail!("Failed to create audio punctuation");
}
Ok(Self { audio_punctuation })
}
Expand All @@ -49,9 +51,9 @@ impl Punctuation {
unsafe {
let text_with_punct_ptr = sherpa_rs_sys::SherpaOfflinePunctuationAddPunct(
self.audio_punctuation,
text.as_ptr(),
text.as_ptr()
);
let text_with_punct = cstr_to_string(text_with_punct_ptr);
let text_with_punct = cstr_to_string(text_with_punct_ptr as _);
sherpa_rs_sys::SherpaOfflinePunctuationFreeText(text_with_punct_ptr);
text_with_punct
}
Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{mem, ptr::null};
use std::{ mem, ptr::null };

use crate::{utils::RawCStr, OnnxConfig};
use crate::{ utils::RawCStr, OnnxConfig };
use eyre::Result;
use sherpa_rs_sys;

use super::{CommonTtsConfig, TtsAudio};
use super::{ CommonTtsConfig, TtsAudio };

pub struct KokoroTts {
tts: *const sherpa_rs_sys::SherpaOnnxOfflineTts,
Expand Down Expand Up @@ -48,7 +48,7 @@ impl KokoroTts {
},
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/tts/matcha.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{mem, ptr::null};
use std::{ mem, ptr::null };

use crate::{utils::RawCStr, OnnxConfig};
use crate::{ utils::RawCStr, OnnxConfig };
use eyre::Result;
use sherpa_rs_sys;

use super::{CommonTtsConfig, TtsAudio};
use super::{ CommonTtsConfig, TtsAudio };

pub struct MatchaTts {
tts: *const sherpa_rs_sys::SherpaOnnxOfflineTts,
Expand Down Expand Up @@ -60,7 +60,7 @@ impl MatchaTts {
kokoro: mem::zeroed::<_>(),
};
let config = sherpa_rs_sys::SherpaOnnxOfflineTtsConfig {
max_num_sentences: 0,
max_num_sentences: config.common_config.max_num_sentences,
model: model_config,
rule_fars: tts_config.rule_fars.map(|v| v.as_ptr()).unwrap_or(null()),
rule_fsts: tts_config.rule_fsts.map(|v| v.as_ptr()).unwrap_or(null()),
Expand Down
Loading

0 comments on commit 5c770a2

Please sign in to comment.