Skip to content

Commit

Permalink
Offline recognizer fixes (#80)
Browse files Browse the repository at this point in the history
* Change transcribe to take an f32 slice

Taking an owned Vec is unnecessary since the ownership of the samples
array doesn't extend past the function.

* Reconstruct the recognizer result more faithfully

Return timestamps and tokens if the underlying FFI result contains it.

* Make Whisper and Moonshine configs clonable
  • Loading branch information
vlovich authored Jan 22, 2025
1 parent 93e1cfa commit 2643f13
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 34 deletions.
40 changes: 40 additions & 0 deletions crates/sherpa-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ mod utils;
#[cfg(feature = "tts")]
pub mod tts;

use std::ffi::CStr;

#[cfg(feature = "sys")]
pub use sherpa_rs_sys;

use eyre::{bail, Result};
use utils::cstr_to_string;

pub fn get_default_provider() -> String {
"cpu".into()
Expand Down Expand Up @@ -81,6 +84,43 @@ pub struct OnnxConfig {
pub num_threads: i32,
}

#[derive(Debug, Clone)]
pub struct OfflineRecognizerResult {
pub lang: String,
pub text: String,
pub timestamps: Vec<f32>,
pub tokens: Vec<String>,
}

impl OfflineRecognizerResult {
fn new(result: &sherpa_rs_sys::SherpaOnnxOfflineRecognizerResult) -> Self {
let lang = cstr_to_string(result.lang);
let text = cstr_to_string(result.text);
let count = result.count.try_into().unwrap();
let timestamps = if result.timestamps.is_null() {
Vec::new()
} else {
unsafe { std::slice::from_raw_parts(result.timestamps, count).to_vec() }
};
let mut tokens = Vec::with_capacity(count);
let mut next_token = result.tokens;

for _ in 0..count {
let token = unsafe { CStr::from_ptr(next_token) };
tokens.push(token.to_string_lossy().into_owned());
next_token = next_token
.wrapping_byte_offset(token.to_bytes_with_nul().len().try_into().unwrap());
}

Self {
lang,
text,
timestamps,
tokens,
}
}
}

impl Default for OnnxConfig {
fn default() -> Self {
Self {
Expand Down
18 changes: 5 additions & 13 deletions crates/sherpa-rs/src/moonshine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use crate::{get_default_provider, utils::RawCStr};
use eyre::{bail, Result};
use std::ptr::null;

Expand All @@ -10,13 +7,9 @@ pub struct MoonshineRecognizer {
recognizer: *const sherpa_rs_sys::SherpaOnnxOfflineRecognizer,
}

#[derive(Debug)]
pub struct MoonshineRecognizerResult {
pub text: String,
// pub timestamps: Vec<f32>,
}
pub type MoonshineRecognizerResult = super::OfflineRecognizerResult;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct MoonshineConfig {
pub preprocessor: String,

Expand Down Expand Up @@ -128,7 +121,7 @@ impl MoonshineRecognizer {
Ok(Self { recognizer })
}

pub fn transcribe(&mut self, sample_rate: u32, samples: Vec<f32>) -> MoonshineRecognizerResult {
pub fn transcribe(&mut self, sample_rate: u32, samples: &[f32]) -> MoonshineRecognizerResult {
unsafe {
let stream = sherpa_rs_sys::SherpaOnnxCreateOfflineStream(self.recognizer);
sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline(
Expand All @@ -140,8 +133,7 @@ impl MoonshineRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text as _);
let result = MoonshineRecognizerResult { text };
let result = MoonshineRecognizerResult::new(&raw_result);
// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
Expand Down
2 changes: 1 addition & 1 deletion crates/sherpa-rs/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl Drop for RawCStr {
}
}

pub fn cstr_to_string(ptr: *mut c_char) -> String {
pub fn cstr_to_string(ptr: *const c_char) -> String {
unsafe {
if ptr.is_null() {
String::new()
Expand Down
22 changes: 6 additions & 16 deletions crates/sherpa-rs/src/whisper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
};
use crate::{get_default_provider, utils::RawCStr};
use eyre::{bail, Result};
use std::ptr::null;

Expand All @@ -10,13 +7,9 @@ pub struct WhisperRecognizer {
recognizer: *const sherpa_rs_sys::SherpaOnnxOfflineRecognizer,
}

#[derive(Debug)]
pub struct WhisperRecognizerResult {
pub text: String,
// pub timestamps: Vec<f32>,
}
pub type WhisperRecognizerResult = super::OfflineRecognizerResult;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct WhisperConfig {
pub decoder: String,
pub encoder: String,
Expand Down Expand Up @@ -135,7 +128,7 @@ impl WhisperRecognizer {
Ok(Self { recognizer })
}

pub fn transcribe(&mut self, sample_rate: u32, samples: Vec<f32>) -> WhisperRecognizerResult {
pub fn transcribe(&mut self, sample_rate: u32, samples: &[f32]) -> WhisperRecognizerResult {
unsafe {
let stream = sherpa_rs_sys::SherpaOnnxCreateOfflineStream(self.recognizer);
sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline(
Expand All @@ -147,10 +140,7 @@ impl WhisperRecognizer {
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text as _);
// let timestamps: &[f32] =
// std::slice::from_raw_parts(raw_result.timestamps, raw_result.count as usize);
let result = WhisperRecognizerResult { text };
let result = WhisperRecognizerResult::new(&raw_result);
// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
Expand Down Expand Up @@ -200,7 +190,7 @@ mod tests {
let mut recognizer = WhisperRecognizer::new(config).unwrap();

let start_t = Instant::now();
let result = recognizer.transcribe(sample_rate, samples);
let result = recognizer.transcribe(sample_rate, &samples);
println!("{:?}", result);
println!("Time taken for transcription: {:?}", start_t.elapsed());
}
Expand Down
2 changes: 1 addition & 1 deletion examples/moonshine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn main() {
let mut recognizer = MoonshineRecognizer::new(config).unwrap();

let start_t = std::time::Instant::now();
let result = recognizer.transcribe(sample_rate, samples);
let result = recognizer.transcribe(sample_rate, &samples);
println!("✅ Text: {}", result.text);
println!("⏱️ Time taken for transcription: {:?}", start_t.elapsed());
}
4 changes: 2 additions & 2 deletions examples/vad_whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn main() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;
let transcript = recognizer.transcribe(sample_rate, segment.samples.clone());
let transcript = recognizer.transcribe(sample_rate, &segment.samples);

// Compute the speaker embedding
let mut embedding = extractor
Expand Down Expand Up @@ -96,7 +96,7 @@ fn main() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;
let transcript = recognizer.transcribe(sample_rate, segment.samples.clone());
let transcript = recognizer.transcribe(sample_rate, &segment.samples);

// Compute the speaker embedding
let mut embedding = extractor
Expand Down
2 changes: 1 addition & 1 deletion examples/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn main() {
let mut recognizer = WhisperRecognizer::new(config).unwrap();

let start_t = std::time::Instant::now();
let result = recognizer.transcribe(sample_rate, samples);
let result = recognizer.transcribe(sample_rate, &samples);
println!("✅ Text: {}", result.text);
println!("⏱️ Time taken for transcription: {:?}", start_t.elapsed());
}

0 comments on commit 2643f13

Please sign in to comment.