diff --git a/kwave/utils/filters.py b/kwave/utils/filters.py index 50aab667..19a33a9a 100644 --- a/kwave/utils/filters.py +++ b/kwave/utils/filters.py @@ -8,7 +8,7 @@ from .checks import is_number from .data import scale_SI -from .math import find_closest, sinc, next_pow2, norm_var, gaussian +from .math import sinc, next_pow2, norm_var, gaussian from .matrix import num_dim, num_dim2 from .signals import get_win from ..kgrid import kWaveGrid @@ -162,15 +162,11 @@ def spect( def extract_amp_phase( - data: np.ndarray, Fs: float, source_freq: float, dim: Tuple[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning" -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + data: np.ndarray, Fs: float, source_freq: float, dim: str = "auto", fft_padding: int = 3, window: str = "hann" +) -> Tuple[np.ndarray, np.ndarray, float]: """ Extract the amplitude and phase information at a specified frequency from a vector or matrix of time series data. - The amplitude and phase are extracted from the frequency spectrum, which is calculated using a windowed and zero - padded FFT. The values are extracted at the frequency closest to source_freq. By default, the time dimension is set - to the highest non-singleton dimension. - Args: data: Matrix of time signals [s] Fs: Sampling frequency [Hz] @@ -181,10 +177,8 @@ def extract_amp_phase( Returns: A tuple of the amplitude, phase and frequency of the extracted signal. - """ - - # check for the dim input + # Automatic detection of time dimension if dim == "auto": dim = num_dim(data) if dim == 2 and data.shape[1] == 1: @@ -194,10 +188,9 @@ def extract_amp_phase( # input data win, coherent_gain = get_win(data.shape[dim], window) # this list magic in Python comes from the use of ones in MATLAB - # TODO: simplify this - win = np.reshape(win, [1] * (dim - 1) + [len(win)]) + win = np.reshape(win, [dim - 1, len(win)]) - # apply window to time dimension of input data + # Apply window to time dimension of input data data = win * data # compute amplitude and phase spectra @@ -701,7 +694,6 @@ def smooth(a: np.ndarray, restore_max: Optional[bool] = False, window_type: Opti # get the window, taking the absolute value to discard machine precision # negative values - from .signals import get_win win, _ = get_win(grid_size, type_=window_type, rotation=DEF_USE_ROTATION, symmetric=window_symmetry) win = np.abs(win) diff --git a/tests/matlab_test_data_collectors/matlab_collectors/collect_extract_amp_phase.m b/tests/matlab_test_data_collectors/matlab_collectors/collect_extract_amp_phase.m new file mode 100644 index 00000000..401d9bbf --- /dev/null +++ b/tests/matlab_test_data_collectors/matlab_collectors/collect_extract_amp_phase.m @@ -0,0 +1,58 @@ +params = { ... + { + randn(1, 1000), ... % 1D data + 1000, ... % Sampling frequency + 50, ... % Source frequency + 'Dim', 1, ... % Optional parameter: dimension + 'FFTPadding', 2, ... % Optional parameter: FFT padding + 'Window', 'Hanning' % Optional parameter: window type + }, ... + { + randn(1000, 1), ... % 1D data in different dimension + 2000, ... % Sampling frequency + 100, ... % Source frequency + 'Dim', 2, ... % Optional parameter: dimension + 'FFTPadding', 3, ... % Optional parameter: FFT padding + 'Window', 'Blackman' % Optional parameter: window type + }, ... + { + randn(10, 100), ... % 2D data + 500, ... % Sampling frequency + 10, ... % Source frequency + 'Dim', 2, ... % Optional parameter: dimension + 'FFTPadding', 4, ... % Optional parameter: FFT padding + 'Window', 'Hamming' % Optional parameter: window type + }, ... + { + randn(50, 50, 50), ... % 3D data + 1000, ... % Sampling frequency + 250, ... % Source frequency + 'Dim', 3, ... % Optional parameter: dimension + 'FFTPadding', 5, ... % Optional parameter: FFT padding + 'Window', 'Hann' % Optional parameter: window type + }, ... + { + randn(100, 100, 100, 10), ... % 4D data + 2000, ... % Sampling frequency + 500, ... % Source frequency + 'Dim', 4, ... % Optional parameter: dimension + 'FFTPadding', 3, ... % Optional parameter: FFT padding + 'Window', 'Kaiser' % Optional parameter: window type + }, ... +}; + +output_file = 'collectedValues/extract_amp_phase.mat'; +recorder = utils.TestRecorder(output_file); +for param_idx = 1:length(params) + + [amp, phase, freq] = extractAmpPhase(params{param_idx}{:}); + + recorder.recordVariable('params', params{param_idx}); + recorder.recordVariable('amp', amp); + recorder.recordVariable('phase', phase); + recorder.recordVariable('freq', freq); + recorder.increment(); + +end +recorder.saveRecordsToDisk(); +disp('Done.') \ No newline at end of file diff --git a/tests/test_filterutils.py b/tests/test_filterutils.py index 5e45ab7c..29ae9be9 100644 --- a/tests/test_filterutils.py +++ b/tests/test_filterutils.py @@ -1,9 +1,26 @@ from math import pi +import os +from pathlib import Path import numpy as np from kwave.reconstruction.beamform import envelope_detection -from kwave.utils.filters import fwhm +from kwave.utils.filters import extract_amp_phase, fwhm +from tests.matlab_test_data_collectors.python_testers.utils.record_reader import TestRecordReader + + +def test_extract_amp_phase(): + reader = TestRecordReader(os.path.join(Path(__file__).parent, "collectedValues/extract.mat")) + + for _ in range(len(reader)): + data, Fs, source_freq, dim, fft_padding, window = reader.expected_value_of("params") + + amp, phase, freq = extract_amp_phase(data, Fs, source_freq, dim, fft_padding, window) + + assert np.allclose(amp, reader.expected_value_of("amp")), "amp did not match expected lin_ind" + assert np.allclose(phase, reader.expected_value_of("phase")), "phase not match expected is" + assert np.allclose(freq, reader.expected_value_of), "freq did not match expected ks" + reader.increment() def test_envelope_detection():