Skip to content

Commit

Permalink
add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
waltsims committed Jun 3, 2024
1 parent c19e066 commit dccffc9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 15 deletions.
20 changes: 6 additions & 14 deletions kwave/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .checks import is_number
from .data import scale_SI
from .math import find_closest, sinc, next_pow2, norm_var, gaussian
from .math import sinc, next_pow2, norm_var, gaussian
from .matrix import num_dim, num_dim2
from .signals import get_win
from ..kgrid import kWaveGrid
Expand Down Expand Up @@ -162,15 +162,11 @@ def spect(


def extract_amp_phase(
data: np.ndarray, Fs: float, source_freq: float, dim: Tuple[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
data: np.ndarray, Fs: float, source_freq: float, dim: str = "auto", fft_padding: int = 3, window: str = "hann"
) -> Tuple[np.ndarray, np.ndarray, float]:
"""
Extract the amplitude and phase information at a specified frequency from a vector or matrix of time series data.
The amplitude and phase are extracted from the frequency spectrum, which is calculated using a windowed and zero
padded FFT. The values are extracted at the frequency closest to source_freq. By default, the time dimension is set
to the highest non-singleton dimension.
Args:
data: Matrix of time signals [s]
Fs: Sampling frequency [Hz]
Expand All @@ -181,10 +177,8 @@ def extract_amp_phase(
Returns:
A tuple of the amplitude, phase and frequency of the extracted signal.
"""

# check for the dim input
# Automatic detection of time dimension
if dim == "auto":
dim = num_dim(data)
if dim == 2 and data.shape[1] == 1:
Expand All @@ -194,10 +188,9 @@ def extract_amp_phase(
# input data
win, coherent_gain = get_win(data.shape[dim], window)
# this list magic in Python comes from the use of ones in MATLAB
# TODO: simplify this
win = np.reshape(win, [1] * (dim - 1) + [len(win)])
win = np.reshape(win, [dim - 1, len(win)])

# apply window to time dimension of input data
# Apply window to time dimension of input data
data = win * data

# compute amplitude and phase spectra
Expand Down Expand Up @@ -701,7 +694,6 @@ def smooth(a: np.ndarray, restore_max: Optional[bool] = False, window_type: Opti

# get the window, taking the absolute value to discard machine precision
# negative values
from .signals import get_win

win, _ = get_win(grid_size, type_=window_type, rotation=DEF_USE_ROTATION, symmetric=window_symmetry)
win = np.abs(win)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
params = { ...
{
randn(1, 1000), ... % 1D data
1000, ... % Sampling frequency
50, ... % Source frequency
'Dim', 1, ... % Optional parameter: dimension
'FFTPadding', 2, ... % Optional parameter: FFT padding
'Window', 'Hanning' % Optional parameter: window type
}, ...
{
randn(1000, 1), ... % 1D data in different dimension
2000, ... % Sampling frequency
100, ... % Source frequency
'Dim', 2, ... % Optional parameter: dimension
'FFTPadding', 3, ... % Optional parameter: FFT padding
'Window', 'Blackman' % Optional parameter: window type
}, ...
{
randn(10, 100), ... % 2D data
500, ... % Sampling frequency
10, ... % Source frequency
'Dim', 2, ... % Optional parameter: dimension
'FFTPadding', 4, ... % Optional parameter: FFT padding
'Window', 'Hamming' % Optional parameter: window type
}, ...
{
randn(50, 50, 50), ... % 3D data
1000, ... % Sampling frequency
250, ... % Source frequency
'Dim', 3, ... % Optional parameter: dimension
'FFTPadding', 5, ... % Optional parameter: FFT padding
'Window', 'Hann' % Optional parameter: window type
}, ...
{
randn(100, 100, 100, 10), ... % 4D data
2000, ... % Sampling frequency
500, ... % Source frequency
'Dim', 4, ... % Optional parameter: dimension
'FFTPadding', 3, ... % Optional parameter: FFT padding
'Window', 'Kaiser' % Optional parameter: window type
}, ...
};

output_file = 'collectedValues/extract_amp_phase.mat';
recorder = utils.TestRecorder(output_file);
for param_idx = 1:length(params)

[amp, phase, freq] = extractAmpPhase(params{param_idx}{:});

recorder.recordVariable('params', params{param_idx});
recorder.recordVariable('amp', amp);
recorder.recordVariable('phase', phase);
recorder.recordVariable('freq', freq);
recorder.increment();

end
recorder.saveRecordsToDisk();
disp('Done.')
19 changes: 18 additions & 1 deletion tests/test_filterutils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
from math import pi
import os
from pathlib import Path

import numpy as np

from kwave.reconstruction.beamform import envelope_detection
from kwave.utils.filters import fwhm
from kwave.utils.filters import extract_amp_phase, fwhm
from tests.matlab_test_data_collectors.python_testers.utils.record_reader import TestRecordReader


def test_extract_amp_phase():
reader = TestRecordReader(os.path.join(Path(__file__).parent, "collectedValues/extract.mat"))

for _ in range(len(reader)):
data, Fs, source_freq, dim, fft_padding, window = reader.expected_value_of("params")

amp, phase, freq = extract_amp_phase(data, Fs, source_freq, dim, fft_padding, window)

assert np.allclose(amp, reader.expected_value_of("amp")), "amp did not match expected lin_ind"
assert np.allclose(phase, reader.expected_value_of("phase")), "phase not match expected is"
assert np.allclose(freq, reader.expected_value_of), "freq did not match expected ks"
reader.increment()


def test_envelope_detection():
Expand Down

0 comments on commit dccffc9

Please sign in to comment.