Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/top matches #5

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 92 additions & 57 deletions examples/diarize.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
cargo run --example diarize
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example diarize motivation.wav
*/

use eyre::{bail, Result};
Expand All @@ -11,9 +12,73 @@ use sherpa_rs::{
};
use std::io::Cursor;

fn get_speaker_name(
embedding_manager: &mut embedding_manager::EmbeddingManager,
embedding: &mut [f32],
speaker_counter: &mut i32,
max_speakers: i32,
) -> String {
let mut name = String::from("unknown");

if *speaker_counter == 0 {
name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), embedding).unwrap();
*speaker_counter += 1;
} else if *speaker_counter <= max_speakers {
if let Some(search_result) = embedding_manager.search(embedding, 0.5) {
name = search_result;
} else {
name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), embedding).unwrap();
*speaker_counter += 1;
}
} else {
let matches = embedding_manager.get_best_matches(embedding, 0.2, *speaker_counter);
if let Some(name_match) = matches.first().map(|m| m.name.clone()) {
name = name_match;
}
}

name
}

fn process_speech_segment(
vad: &mut Vad,
sample_rate: i32,
mut embedding_manager: &mut embedding_manager::EmbeddingManager,
extractor: &mut speaker_id::EmbeddingExtractor,
speaker_counter: &mut i32,
max_speakers: i32,
) -> Result<()> {
while !vad.is_empty() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;

// Compute the speaker embedding
let mut embedding = extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

let name = get_speaker_name(
&mut embedding_manager,
&mut embedding,
speaker_counter,
max_speakers,
);
println!(
"({}) start={}s end={}s",
name,
start_sec,
start_sec + duration_sec
);
vad.pop();
}
Ok(())
}

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/motivation.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;
let max_speakers = 2;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down Expand Up @@ -43,14 +108,14 @@ fn main() -> Result<()> {
let mut embedding_manager =
embedding_manager::EmbeddingManager::new(extractor.embedding_size.try_into().unwrap()); // Assuming dimension 512 for embeddings

let mut speaker_counter = 0;
let mut speaker_counter = 1;

let vad_model = "silero_vad.onnx".into();
let window_size: usize = 512;
let config = VadConfig::new(
vad_model,
0.4,
0.4,
0.5,
0.5,
0.5,
sample_rate,
window_size.try_into().unwrap(),
Expand All @@ -66,61 +131,31 @@ fn main() -> Result<()> {
vad.accept_waveform(window.to_vec()); // Convert slice to Vec
if vad.is_speech() {
while !vad.is_empty() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;

// Compute the speaker embedding
let mut embedding =
extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

let name = if let Some(speaker_name) = embedding_manager.search(&embedding, 0.45) {
speaker_name
} else {
// Register a new speaker and add the embedding
let name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), &mut embedding)?;

speaker_counter += 1;
name
};
println!(
"({}) start={}s end={}s",
name,
start_sec,
start_sec + duration_sec
);
vad.pop();
process_speech_segment(
&mut vad,
sample_rate,
&mut embedding_manager,
&mut extractor,
&mut speaker_counter,
max_speakers,
)?;
}
}

index += window_size;
}

if index < samples.len() {
let remaining_samples = &samples[index..];
vad.accept_waveform(remaining_samples.to_vec());
while !vad.is_empty() {
let segment = vad.front();
let start_sec = (segment.start as f32) / sample_rate as f32;
let duration_sec = (segment.samples.len() as f32) / sample_rate as f32;

// Compute the speaker embedding
let mut embedding =
extractor.compute_speaker_embedding(sample_rate, segment.samples)?;

let name = if let Some(speaker_name) = embedding_manager.search(&embedding, 0.45) {
speaker_name
} else {
// Register a new speaker and add the embedding
let name = format!("speaker {}", speaker_counter);
embedding_manager.add(name.clone(), &mut embedding)?;

speaker_counter += 1;
name
};
println!("({}) start={}s duration={}s", name, start_sec, duration_sec);
vad.pop();
}
vad.flush();
// process reamaining
while !vad.is_empty() {
process_speech_segment(
&mut vad,
sample_rate,
&mut embedding_manager,
&mut extractor,
&mut speaker_counter,
max_speakers,
)?;
}

Ok(())
}
4 changes: 2 additions & 2 deletions examples/diarize_whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-o
wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/sam_altman.wav -O samples/sam_altman.wav
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/sam_altman.wav -O sam_altman.wav
cargo run --example diarize_whisper
*/

Expand Down Expand Up @@ -34,7 +34,7 @@ fn read_audio_file(path: &str) -> Result<(i32, Vec<f32>)> {

fn main() -> Result<()> {
// Read audio data from the file
let (sample_rate, mut samples) = read_audio_file("samples/sam_altman.wav")?;
let (sample_rate, mut samples) = read_audio_file("sam_altman.wav")?;

// Pad with 3 seconds of slience so vad will able to detect stop
for _ in 0..3 * sample_rate {
Expand Down
7 changes: 4 additions & 3 deletions examples/language_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
cargo run --example language_id
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/16hz_mono_pcm_s16le.wav -O 16hz_mono_pcm_s16le.wav
cargo run --example language_id 16hz_mono_pcm_s16le.wav
*/

use eyre::{bail, Result};
use sherpa_rs::language_id;
use std::io::Cursor;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/16hz_mono_pcm_s16le.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down
7 changes: 4 additions & 3 deletions examples/speaker_embedding.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
cargo run --example speaker_embedding
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/16hz_mono_pcm_s16le.wav -O 16hz_mono_pcm_s16le.wav
cargo run --example speaker_embedding 16hz_mono_pcm_s16le.wav
*/

use eyre::{bail, Result};
Expand All @@ -9,8 +10,8 @@ use std::io::Cursor;
use std::path::PathBuf;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/16hz_mono_pcm_s16le.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

// Use Cursor to create a reader from the byte slice
let cursor = Cursor::new(audio_data);
Expand Down
9 changes: 3 additions & 6 deletions examples/speaker_id.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/nemo_en_speakerverification_speakernet.onnx
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/biden.wav -O biden.wav
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/obama.wav -O obama.wav
cargo run --example speaker_id
*/
use eyre::{bail, Result};
Expand Down Expand Up @@ -29,12 +31,7 @@ fn main() -> Result<()> {
env_logger::init();

// Define paths to the audio files
let audio_files = vec![
"samples/obama.wav",
"samples/trump.wav",
"samples/biden.wav",
"samples/biden1.wav",
];
let audio_files = vec!["samples/obama.wav", "biden.wav"];

// Create the extractor configuration and extractor
let mut model_path = PathBuf::from(std::env::current_dir()?);
Expand Down
3 changes: 2 additions & 1 deletion examples/transcribe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example transcribe
*/

Expand All @@ -26,7 +27,7 @@ fn read_audio_file(path: &str) -> Result<(i32, Vec<f32>)> {
}

fn main() -> Result<()> {
let (sample_rate, samples) = read_audio_file("samples/motivation.wav")?;
let (sample_rate, samples) = read_audio_file("motivation.wav")?;

// Check if the sample rate is 16000
if sample_rate != 16000 {
Expand Down
5 changes: 3 additions & 2 deletions examples/vad_segment.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
/*
wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example vad_segment
*/
use eyre::{bail, Result};
use sherpa_rs::vad::{Vad, VadConfig};
use std::io::Cursor;

fn main() -> Result<()> {
// Read audio data from the file
let audio_data: &[u8] = include_bytes!("../samples/motivation.wav");
let file_path = std::env::args().nth(1).expect("Missing file path argument");
let audio_data = std::fs::read(file_path)?;

let cursor = Cursor::new(audio_data);
let mut reader = hound::WavReader::new(cursor)?;
Expand Down
Binary file removed samples/16hz_mono_pcm_s16le.wav
Binary file not shown.
Binary file removed samples/biden.wav
Binary file not shown.
Binary file removed samples/biden1.wav
Binary file not shown.
Binary file removed samples/motivation.wav
Binary file not shown.
Binary file removed samples/obama.wav
Binary file not shown.
Binary file removed samples/trump.wav
Binary file not shown.
41 changes: 40 additions & 1 deletion src/embedding_manager.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use eyre::{bail, Result};
use std::ffi::{CStr, CString};

#[derive(Debug)]
use crate::cstr_to_string;

#[derive(Debug, Clone)]
pub struct EmbeddingManager {
pub(crate) manager: *const sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManager,
}

#[derive(Debug, Clone)]
pub struct SpeakerMatch {
pub name: String,
pub score: f32,
}

impl EmbeddingManager {
pub fn new(dimension: i32) -> Self {
unsafe {
Expand All @@ -29,6 +37,37 @@ impl EmbeddingManager {
}
}

pub fn get_best_matches(
&mut self,
embedding: &[f32],
threshold: f32,
n: i32,
) -> Vec<SpeakerMatch> {
unsafe {
let result_ptr = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
self.manager,
embedding.to_owned().as_mut_ptr(),
threshold,
n,
);
if result_ptr.is_null() {
return Vec::new();
}
let result = result_ptr.read();

let matches_c = std::slice::from_raw_parts(result.matches, result.count as usize);
let mut matches: Vec<SpeakerMatch> = Vec::new();
for i in 0..result.count {
let match_c = matches_c[i as usize];
let name = cstr_to_string!(match_c.name);
let score = match_c.score;
matches.push(SpeakerMatch { name, score });
}
sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(result_ptr);
matches
}
}

pub fn add(&mut self, name: String, embedding: &mut [f32]) -> Result<()> {
let name_cstr = CString::new(name.clone())?;

Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ macro_rules! cstr {
CString::new($s).expect("Failed to create CString")
};
}

#[macro_export]
macro_rules! cstr_to_string {
($ptr:expr) => {
std::ffi::CStr::from_ptr($ptr).to_string_lossy().to_string()
};
}
6 changes: 6 additions & 0 deletions src/vad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ impl Vad {
}
}

pub fn flush(&mut self) {
unsafe {
sherpa_rs_sys::SherpaOnnxVoiceActivityDetectorFlush(self.vad);
}
}

pub fn accept_waveform(&mut self, mut samples: Vec<f32>) {
let samples_ptr = samples.as_mut_ptr();
let samples_length = samples.len();
Expand Down
Loading