From 8f0461690d9a885ffae5e0a7117aff460ab27ab7 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Tue, 7 Jan 2025 00:34:44 +0530 Subject: [PATCH] Specify window_length dtype requirement in tf.keras.ops.istft in math.py (#20728) The `window_length` parameter in `tf.keras.ops.istft` requires `tf.int32` dtype, but this isn't documented. This can cause unexpected `ValueError` when using `tf.int64` and `tf.int16` Here is the Example case: ``` import tensorflow as tf input_dict = { 'stfts': tf.constant([[-0.87817144+1.14583987j, -0.32066484+0.25565411j]], dtype=tf.complex128), 'frame_length': tf.constant(256, dtype=tf.int16), 'frame_step': tf.constant(5120,dtype=tf.int64) } result = tf.signal.inverse_stft(**input_dict) print(result) ``` The code throws the following error: ``` ValueError: window_length: Tensor conversion requested dtype int32 for Tensor with dtype int64 ``` --- keras/src/ops/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index 7caaa41f6283..6cedef62cee4 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -888,7 +888,7 @@ def istft( sequence_length: An integer representing the sequence length. sequence_stride: An integer representing the sequence hop size. fft_length: An integer representing the size of the FFT that produced - `stft`. + `stft`. Should be of type `int32`. length: An integer representing the output is clipped to exactly length. If not specified, no padding or clipping take place. Defaults to `None`.