Skip to content

Commit

Permalink
[gpu]: STFT: Added support for win_size less than frame_size.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkowalc1 committed Dec 12, 2024
1 parent d60e95f commit 214c5ce
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,22 @@ KERNEL(stft_ref)(
const int frame_step = (int)frame_step_buff[0];
const int window_size = INPUT1_SIZE_X;

const INPUT0_TYPE* restrict signal_for_this_frame = signal + batch*INPUT0_SIZE_X + frame_id*frame_step;
// Handling case where window size is smaller than frame size.
const int start_offset = (frame_size - window_size) / 2;

const INPUT0_TYPE* restrict signal_for_this_frame = signal + batch*INPUT0_SIZE_X + frame_id*frame_step + start_offset;

// FT from def for single freq for given frame:
cfloat freq_val = czero();

// dft_power = 2*PI*(k/N) from dft def.
const float dft_power = 2.0f * M_PI_F * (float)freq_id / (float)frame_size;

for(int i = 0; i < frame_size; ++i) {
for(int i = 0; i < window_size; ++i) {
const float signal_val = (float)signal_for_this_frame[i];
const float window_val = (float)window[i];
const float x_i = signal_val*window_val;
const cfloat e_i = expmi(dft_power*(float)i);
const cfloat e_i = expmi(dft_power*(float)(i+start_offset));
freq_val = cadd(freq_val, crmult(e_i, x_i));
}

Expand Down
215 changes: 214 additions & 1 deletion src/tests/test_utils/unit_test_utils/tests_data/stft_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -1083,4 +1083,217 @@ TEST_DATA(LIST(1, 48),
1.335275,
-1.2002,
-5.702536),
"test_case_5");
"test_case_5");

TEST_DATA(LIST(48),
LIST(7),
LIST(6, 13, 2),
11,
3,
true,
LIST(-0.41676,
-0.05627,
-2.1362,
1.64027,
-1.79344,
-0.84175,
0.50288,
-1.24529,
-1.05795,
-0.90901,
0.55145,
2.29221,
0.04154,
-1.11793,
0.53906,
-0.59616,
-0.01913,
1.175,
-0.74787,
0.00903,
-0.87811,
-0.15643,
0.25657,
-0.98878,
-0.33882,
-0.23618,
-0.63766,
-1.18761,
-1.42122,
-0.1535,
-0.26906,
2.23137,
-2.43477,
0.11273,
0.37044,
1.35963,
0.50186,
-0.84421,
0.00001,
0.54235,
-0.31351,
0.77101,
-1.86809,
1.73118,
1.46768,
-0.33568,
0.61134,
0.04797),
LIST(0.0, 0.25, 0.75, 1.0, 0.75, 0.25, 0.0),
LIST(-1.71092,
0.,
-2.41009,
0.,
2.23022,
0.,
-0.7409,
0.,
0.45297,
0.,
-1.11149,
0.,
-1.14862,
0.,
-2.14551,
0.,
-1.16026,
0.,
-0.65135,
0.,
1.83099,
0.,
-0.1793,
0.,
-0.2968,
0.,
1.47212,
0.71877,
2.17268,
0.79158,
-2.28473,
-0.93586,
0.4625,
0.34192,
-0.56009,
-0.32899,
0.93528,
0.44276,
1.11077,
0.05564,
1.82719,
-0.1221,
0.71587,
1.50743,
1.10802,
-0.41842,
-1.71345,
-0.67438,
0.05781,
0.40969,
0.4558,
-0.24137,
-0.54856,
-1.56669,
-1.47087,
-1.22889,
2.1535,
1.84441,
0.18738,
-0.28908,
0.66134,
0.88008,
-0.66811,
-0.52077,
-1.02705,
-0.15929,
-1.12869,
0.2893,
0.0583,
-1.66476,
-2.16394,
0.18383,
1.42389,
1.02343,
0.32308,
-0.7337,
-0.68826,
0.55139,
-0.91886,
1.85309,
0.52177,
0.97814,
-1.50306,
-2.29021,
-0.76526,
-0.28515,
-0.47423,
-1.4385,
0.63386,
0.43591,
0.90989,
0.38369,
0.51776,
-0.36462,
-0.31809,
0.57129,
2.99689,
0.98808,
-1.06897,
-0.98176,
-0.81284,
0.72147,
0.63521,
-1.1571,
1.74128,
-1.03922,
0.14692,
-0.1082,
0.64531,
1.98433,
0.856,
1.12631,
0.14133,
1.66429,
-0.63884,
-0.57479,
-0.6772,
-0.71798,
-0.19529,
0.22579,
0.09013,
0.66192,
-2.7275,
-2.70068,
0.6808,
0.74142,
0.95724,
-0.28153,
-0.33733,
2.09067,
-0.89051,
-0.04374,
-0.16546,
-0.69762,
-0.12612,
-1.43585,
-0.37017,
-1.74231,
0.00518,
-1.6207,
0.29356,
0.84215,
0.2579,
0.98549,
0.05179,
-0.0244,
0.03393,
-1.30044,
1.1122,
3.98255,
-0.23778,
-0.54982,
-0.43563,
-0.19685,
0.08299,
-2.86001),
"test_case_6");

0 comments on commit 214c5ce

Please sign in to comment.