Skip to content

Commit

Permalink
feat: implement Send and Sync traits for TTS and AudioTag structs
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jan 18, 2025
1 parent f59f4c1 commit b54a965
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 7 deletions.
16 changes: 15 additions & 1 deletion crates/sherpa-rs/src/audio_tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl AudioTag {
let audio_tag = unsafe { sherpa_rs_sys::SherpaOnnxCreateAudioTagging(&sherpa_config) };

if audio_tag.is_null() {
bail!("Failed to create audio tagging")
bail!("Failed to create audio tagging");
}
Ok(Self {
audio_tag,
Expand All @@ -64,6 +64,7 @@ impl AudioTag {
samples.as_ptr(),
samples.len() as i32,
);

let results = sherpa_rs_sys::SherpaOnnxAudioTaggingCompute(
self.audio_tag,
stream,
Expand All @@ -75,7 +76,20 @@ impl AudioTag {
let event_name = cstr_to_string((*event).name);
events.push(event_name);
}

sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
}
events
}
}

unsafe impl Send for AudioTag {}
unsafe impl Sync for AudioTag {}

impl Drop for AudioTag {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroyAudioTagging(self.audio_tag);
}
}
}
13 changes: 7 additions & 6 deletions crates/sherpa-rs/src/diarize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{get_default_provider, utils::RawCStr};
use eyre::{bail, Result};
use std::path::Path;
use std::{path::Path, ptr::null_mut};

#[derive(Debug)]
pub struct Diarize {
Expand All @@ -14,7 +14,7 @@ pub struct Segment {
pub speaker: i32,
}

type ProgressCallback = Box<dyn Fn(i32, i32) -> i32 + Send + 'static>;
type ProgressCallback = Box<dyn (Fn(i32, i32) -> i32) + Send + 'static>;

#[derive(Debug, Clone)]
pub struct DiarizeConfig {
Expand Down Expand Up @@ -85,7 +85,7 @@ impl Diarize {
let sd = unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflineSpeakerDiarization(&config) };

if sd.is_null() {
bail!("Failed to initialize offline speaker diarization")
bail!("Failed to initialize offline speaker diarization");
}
Ok(Self { sd })
}
Expand All @@ -103,7 +103,7 @@ impl Diarize {
let callback_ptr = callback_box
.as_mut()
.map(|b| b.as_mut() as *mut ProgressCallback as *mut std::ffi::c_void)
.unwrap_or(std::ptr::null_mut());
.unwrap_or(null_mut());

let result = sherpa_rs_sys::SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
self.sd,
Expand All @@ -123,8 +123,9 @@ impl Diarize {
sherpa_rs_sys::SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result);

if !segments_ptr.is_null() && num_segments > 0 {
let segments_result: &[sherpa_rs_sys::SherpaOnnxOfflineSpeakerDiarizationSegment] =
std::slice::from_raw_parts(segments_ptr, num_segments as usize);
let segments_result: &[
sherpa_rs_sys::SherpaOnnxOfflineSpeakerDiarizationSegment
] = std::slice::from_raw_parts(segments_ptr, num_segments as usize);

for segment in segments_result {
// Use segment here
Expand Down
11 changes: 11 additions & 0 deletions crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ impl KokoroTts {
unsafe { super::create(self.tts, text, sid, speed) }
}
}

unsafe impl Send for KokoroTts {}
unsafe impl Sync for KokoroTts {}

impl Drop for KokoroTts {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroyOfflineTts(self.tts);
}
}
}
11 changes: 11 additions & 0 deletions crates/sherpa-rs/src/tts/matcha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ impl MatchaTts {
unsafe { super::create(self.tts, text, sid, speed) }
}
}

unsafe impl Send for MatchaTts {}
unsafe impl Sync for MatchaTts {}

impl Drop for MatchaTts {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroyOfflineTts(self.tts);
}
}
}
11 changes: 11 additions & 0 deletions crates/sherpa-rs/src/tts/vits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ impl VitsTts {
unsafe { super::create(self.tts, text, sid, speed) }
}
}

unsafe impl Send for VitsTts {}
unsafe impl Sync for VitsTts {}

impl Drop for VitsTts {
fn drop(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxDestroyOfflineTts(self.tts);
}
}
}

0 comments on commit b54a965

Please sign in to comment.