Skip to content

Commit

Permalink
beat_detector: init structs BeatDetector and AudioInput
Browse files Browse the repository at this point in the history
The relatively complex AudioInput type (complex because large base
of the code need adoptions) is very nice as this way, users do not
necessarily need to create intermediate arrays with their data.

Thus, the copying into the internal buffer (after applying the
low pass filter), is the only copying needed. Performance, yay!
  • Loading branch information
phip1611 committed Apr 28, 2024
1 parent 8c9cd0a commit 5cd45cb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 69 deletions.
61 changes: 37 additions & 24 deletions src/audio_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,34 @@ impl AudioHistory {
}

/// Update the audio history with fresh samples. The audio samples are
/// expected to be in mono channel, i.e., no stereo interleaving
/// TODO: Update consume iterator and consume enum: Interleaved or Mono
pub fn update(&mut self, mono_samples: &[f32]) {
if mono_samples.len() >= self.audio_buffer.capacity() {
/// expected to be in mono channel format.
pub fn update<'a, I: Iterator<Item = f32>>(&mut self, mono_samples_iter: I) {
let mut len = 0;
mono_samples_iter
.inspect(|sample| {
debug_assert!(sample.is_finite());
debug_assert!(sample.abs() <= 1.0);
})
.for_each(|sample| {
self.audio_buffer.push(sample);
len += 1;
});

if len >= self.audio_buffer.capacity() {
log::warn!(
"Adding {} samples to the audio buffer that only has a capacity for {} samples.",
mono_samples.len(),
len,
self.audio_buffer.capacity()
);
#[cfg(test)]
std::eprintln!(
"WARN: AudioHistory::update: Adding {} samples to the audio buffer that only has a capacity for {} samples.",
mono_samples.len(),
len,
self.audio_buffer.capacity()
);
}

for &sample in mono_samples {
debug_assert!(sample.is_finite());
debug_assert!(sample.abs() <= 1.0);
self.audio_buffer.push(sample);
}
self.total_consumed_items += mono_samples.len();
self.total_consumed_items += len;
}

/// Get the passed time in seconds.
Expand Down Expand Up @@ -204,11 +209,11 @@ mod tests {
let mut hist = AudioHistory::new(2.0);
assert_eq!(hist.total_consumed_items, 0);

hist.update(&[0.0]);
hist.update([0.0].iter().copied());
assert_eq!(hist.total_consumed_items, 1);
assert_eq!(hist.passed_time(), Duration::from_secs_f32(0.5));

hist.update(&[0.0, 0.0]);
hist.update([0.0, 0.0].iter().copied());
assert_eq!(hist.total_consumed_items, 3);
assert_eq!(hist.passed_time(), Duration::from_secs_f32(1.5));
}
Expand All @@ -222,12 +227,12 @@ mod tests {
.map(|x| x as f32 / (DEFAULT_BUFFER_SIZE + 10) as f32)
.collect::<Vec<_>>();

hist.update(&test_data[0..10]);
hist.update(test_data[0..10].iter().copied());
assert_eq!(hist.index_to_sample_number(0), 0);
assert_eq!(hist.index_to_sample_number(10), 10);

// now the buffer is full, but no overflow yet
hist.update(&test_data[10..DEFAULT_BUFFER_SIZE]);
hist.update(test_data[10..DEFAULT_BUFFER_SIZE].iter().copied());
assert_eq!(hist.index_to_sample_number(0), 0);
assert_eq!(hist.index_to_sample_number(10), 10);
assert_eq!(
Expand All @@ -236,7 +241,11 @@ mod tests {
);

// now the buffer overflowed
hist.update(&test_data[DEFAULT_BUFFER_SIZE..DEFAULT_BUFFER_SIZE + 10]);
hist.update(
test_data[DEFAULT_BUFFER_SIZE..DEFAULT_BUFFER_SIZE + 10]
.iter()
.copied(),
);
assert_eq!(hist.index_to_sample_number(0), 10);
assert_eq!(hist.index_to_sample_number(10), 20);
assert_eq!(
Expand All @@ -255,17 +264,21 @@ mod tests {
.map(|x| x as f32 / (DEFAULT_BUFFER_SIZE + 10) as f32)
.collect::<Vec<_>>();

hist.update(&test_data[0..10]);
hist.update(test_data[0..10].iter().copied());
assert_eq!(hist.timestamp_of_index(0), Duration::from_secs_f32(0.0));
assert_eq!(hist.timestamp_of_index(10), Duration::from_secs_f32(5.0));

// now the buffer is full, but no overflow yet
hist.update(&test_data[10..DEFAULT_BUFFER_SIZE]);
hist.update(test_data[10..DEFAULT_BUFFER_SIZE].iter().copied());
assert_eq!(hist.timestamp_of_index(0), Duration::from_secs_f32(0.0));
assert_eq!(hist.timestamp_of_index(10), Duration::from_secs_f32(5.0));

// now the buffer overflowed
hist.update(&test_data[DEFAULT_BUFFER_SIZE..DEFAULT_BUFFER_SIZE + 10]);
hist.update(
test_data[DEFAULT_BUFFER_SIZE..DEFAULT_BUFFER_SIZE + 10]
.iter()
.copied(),
);
assert_eq!(hist.timestamp_of_index(0), Duration::from_secs_f32(5.0));
assert_eq!(hist.timestamp_of_index(10), Duration::from_secs_f32(10.0));
}
Expand All @@ -275,7 +288,7 @@ mod tests {
let (samples, header) = crate::test_utils::samples::sample1_long();

let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

assert_eq!(
(history.passed_time().as_secs_f32() * 1000.0).round() / 1000.0,
Expand All @@ -293,12 +306,12 @@ mod tests {
fn sample_info() {
let mut hist = AudioHistory::new(1.0);

hist.update(&[0.0]);
hist.update([0.0].iter().copied());
assert_eq!(
hist.index_to_sample_info(0).duration_behind,
Duration::from_secs(0)
);
hist.update(&[0.0]);
hist.update([0.0].iter().copied());
assert_eq!(
hist.index_to_sample_info(0).duration_behind,
Duration::from_secs(1)
Expand All @@ -308,7 +321,7 @@ mod tests {
Duration::from_secs(0)
);

hist.update(&[0.0].repeat(hist.data().capacity() * 2));
hist.update([0.0].repeat(hist.data().capacity() * 2).iter().copied());

let sample = hist.index_to_sample_info(0);
assert_eq!(
Expand Down
31 changes: 0 additions & 31 deletions src/audio_input.rs

This file was deleted.

109 changes: 109 additions & 0 deletions src/beat_detector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use crate::EnvelopeInfo;
use crate::{AudioHistory, EnvelopeIterator};
use biquad::{coefficients, Biquad, Coefficients, DirectForm1, ToHertz, Type, Q_BUTTERWORTH_F32};
use std::fmt::{Debug, Formatter};

/// Cutoff frequency for the lowpass filter to detect beats.
const CUTOFF_FREQUENCY_HZ: f32 = 70.0;

/// Information about a beat.
pub type BeatInfo = EnvelopeInfo;

/// The audio input source. Each value must be in range `[-1.0..=1.0]`. This
/// abstraction facilitates the libraries goal to prevent needless copying
/// and buffering of data: internally as well as on a higher level.
pub enum AudioInput<'a, I: Iterator<Item = f32>> {
/// The audio input stream only consists of mono samples.
SliceMono(&'a [f32]),
/// The audio input streams consists of interleaved samples following a
/// LRLRLR or RLRLRL scheme. This is typically the case for stereo channel
/// audio. Internally, the audio will be combined to a mono track.
SliceStereo(&'a [f32]),
/// Custom iterator emitting mono samples in f32 format.
Iterator(I),
}

impl<'a, I: Iterator<Item = f32>> Debug for AudioInput<'a, I> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let variant = match self {
AudioInput::SliceMono(_) => "SliceMono(data...)",
AudioInput::SliceStereo(_) => "SliceStereo(data...)",
AudioInput::Iterator(_) => "Iterator(data...)",
};
f.debug_tuple("AudioInput").field(&variant).finish()
}
}

#[derive(Debug)]
pub struct BeatDetector {
lowpass_filter: DirectForm1<f32>,
history: AudioHistory,
}

impl BeatDetector {
pub fn new(sampling_frequency_hz: f32) -> Self {
let lowpass_filter = BeatDetector::create_lowpass_filter(sampling_frequency_hz);
Self {
lowpass_filter,
history: AudioHistory::new(sampling_frequency_hz),
}
}

/// Consumes the latest audio data and returns if the audio history,
/// consisting of previously captured audio and the new data, contains a
/// beat. This function is supposed to be frequently
/// called everytime new audio data from the input source is available so
/// that:
/// - the latency is low
/// - no beats are missed
///
/// From experience, Linux audio input libraries give you a 20-40ms audio
/// buffer every 20-40ms with the latest data. That's a good rule of thumb.
pub fn detect_beat<'a>(
&mut self,
input: AudioInput<'a, impl Iterator<Item = f32>>,
) -> Option<BeatInfo> {
match input {
AudioInput::SliceMono(slice) => {
let iter = slice.iter().map(|&sample| self.lowpass_filter.run(sample));
self.history.update(iter)
}
AudioInput::SliceStereo(slice) => {
let iter = slice
.chunks(2)
.map(|lr| (lr[0] + lr[1]) / 2.0)
.map(|sample| self.lowpass_filter.run(sample));

self.history.update(iter)
}
AudioInput::Iterator(iter) => self.history.update(iter),
}

// TODO prevent detection of same beat
let mut envelope_iter = EnvelopeIterator::new(&self.history, None);
envelope_iter.next()
}

fn create_lowpass_filter(sampling_frequency_hz: f32) -> DirectForm1<f32> {
// Cutoff frequency.
let f0 = CUTOFF_FREQUENCY_HZ.hz();
// Samling frequency.
let fs = sampling_frequency_hz.hz();

let coefficients =
Coefficients::<f32>::from_params(Type::LowPass, fs, f0, Q_BUTTERWORTH_F32).unwrap();
DirectForm1::<f32>::new(coefficients)
}
}

#[cfg(test)]
mod tests {
use crate::BeatDetector;

#[test]
fn is_send_and_sync() {
fn accept<I: Send + Sync>() {};

accept::<BeatDetector>();
}
}
12 changes: 6 additions & 6 deletions src/envelope_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ mod tests {
{
let (samples, header) = test_utils::samples::sample1_single_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

// Taken from waveform in Audacity.
let peak_sample_index = 1430;
Expand All @@ -316,7 +316,7 @@ mod tests {
{
let (samples, header) = test_utils::samples::sample1_double_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

// Taken from waveform in Audacity.
let peak_sample_index = 1634;
Expand All @@ -339,7 +339,7 @@ mod tests {
{
let (samples, header) = test_utils::samples::holiday_single_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

// Taken from waveform in Audacity.
let peak_sample_index = 820;
Expand All @@ -354,7 +354,7 @@ mod tests {
fn find_envelopes_sample1_single_beat() {
let (samples, header) = test_utils::samples::sample1_single_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

let envelopes = EnvelopeIterator::new(&history, None)
.take(1)
Expand All @@ -367,7 +367,7 @@ mod tests {
fn find_envelopes_sample1_double_beat() {
let (samples, header) = test_utils::samples::sample1_double_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

let envelopes = EnvelopeIterator::new(&history, None)
.map(|info| (info.from.index, info.to.index))
Expand All @@ -386,7 +386,7 @@ mod tests {
fn find_envelopes_holiday_single_beat() {
let (samples, header) = test_utils::samples::holiday_single_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

let envelopes = EnvelopeIterator::new(&history, None)
.map(|info| (info.from.index, info.to.index))
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,20 @@ SOFTWARE.
extern crate std;

mod audio_history;
mod audio_input;
mod envelope_iterator;
mod max_min_iterator;
mod root_iterator;

/// PRIVATE. For tests and helper binaries.
#[cfg(test)]
mod test_utils;
mod beat_detector;

pub use audio_history::{AudioHistory, SampleInfo};
pub use audio_input::AudioInput;
use envelope_iterator::EnvelopeIterator;
pub use envelope_iterator::{EnvelopeIterator, EnvelopeInfo};
use max_min_iterator::MaxMinIterator;
use root_iterator::RootIterator;
pub use beat_detector::{BeatDetector, BeatInfo, AudioInput};

#[cfg(test)]
mod tests {
Expand All @@ -108,7 +108,7 @@ mod tests {

fn _print_sample_stats((samples, header): (Vec<f32>, wav::Header)) {
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

let all_peaks = MaxMinIterator::new(&history, None).collect::<Vec<_>>();

Expand Down
4 changes: 2 additions & 2 deletions src/max_min_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ mod tests {
fn find_maxmin_in_holiday_excerpt() {
let (samples, header) = test_utils::samples::holiday_excerpt();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter().copied());

let iter = MaxMinIterator::new(&history, None);
#[rustfmt::skip]
Expand All @@ -102,7 +102,7 @@ mod tests {
fn find_maxmin_in_sample1_single_beat() {
let (samples, header) = test_utils::samples::sample1_single_beat();
let mut history = AudioHistory::new(header.sampling_rate as f32);
history.update(&samples);
history.update(samples.iter());
let iter = MaxMinIterator::new(&history, None);
#[rustfmt::skip]
Expand Down
Loading

0 comments on commit 5cd45cb

Please sign in to comment.