Skip to content

Commit

Permalink
Specify window_length dtype requirement in tf.keras.ops.istft in math…
Browse files Browse the repository at this point in the history
….py (#20728)

The `window_length` parameter in `tf.keras.ops.istft` requires `tf.int32` dtype, but this isn't documented. This can cause unexpected `ValueError` when using `tf.int64` and `tf.int16`

Here is the Example case:
```
import tensorflow as tf

input_dict = {
    'stfts': tf.constant([[-0.87817144+1.14583987j, -0.32066484+0.25565411j]], dtype=tf.complex128),
    'frame_length': tf.constant(256, dtype=tf.int16),
    'frame_step': tf.constant(5120,dtype=tf.int64)
}
result = tf.signal.inverse_stft(**input_dict)
print(result)
```
The code throws the following error:
```
ValueError: window_length: Tensor conversion requested dtype int32 for Tensor with dtype int64
```
  • Loading branch information
LakshmiKalaKadali authored Jan 6, 2025
1 parent 1adaaec commit 8f04616
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def istft(
sequence_length: An integer representing the sequence length.
sequence_stride: An integer representing the sequence hop size.
fft_length: An integer representing the size of the FFT that produced
`stft`.
`stft`. Should be of type `int32`.
length: An integer representing the output is clipped to exactly length.
If not specified, no padding or clipping take place. Defaults to
`None`.
Expand Down

0 comments on commit 8f04616

Please sign in to comment.