Skip to content

Commit

Permalink
Add real-time processing for FRCRN_SE_16K
Browse files Browse the repository at this point in the history
Related to modelscope#23

Add real-time processing support for the FRCRN_SE_16K model.

* **clearvoice/models/frcrn_se/frcrn.py**
  - Add a new method `real_time_process` to the `FRCRN_SE_16K` class for real-time processing.
  - Modify the `forward` method to support both offline and real-time processing.
  - Update the `DCCRN` class to handle real-time processing.

* **clearvoice/config/inference/FRCRN_SE_16K.yaml**
  - Change `win_len` to 320 to use 20 ms input windows.
  - Change `win_inc` to 160 to use a 10 ms hop between windows (at 16 kHz).

* **clearvoice/demo.py**
  - Add a new demo case for real-time processing using the `FRCRN_SE_16K` model.

* **clearvoice/demo_with_more_comments.py**
  - Add a new demo case for real-time processing using the `FRCRN_SE_16K` model.
  • Loading branch information
vishwamartur committed Dec 10, 2024
1 parent 2e8ebfd commit 4f874e2
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 5 deletions.
4 changes: 2 additions & 2 deletions clearvoice/config/inference/FRCRN_SE_16K.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ decode_window: 1 #one-pass decoding length
#
# FFT parameters
win_type: 'hanning'
-win_len: 640
-win_inc: 320
+win_len: 320
+win_inc: 160
fft_len: 640
14 changes: 14 additions & 0 deletions clearvoice/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,17 @@

#2nd calling method: process video files listed in .scp file, and write outputs to 'path_to_output_videos_tse_scp/'
myClearVoice(input_path='samples/scp/video_samples.scp', online_write=True, output_path='samples/path_to_output_videos_tse_scp')

##----- Demo Six: use FRCRN_SE_16K model for real-time processing -----------------
if False:
    myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])

    ## 1st calling method: enhance one input waveform in real-time, get the output waveform back,
    ## then save it to output_FRCRN_SE_16K_realtime.wav
    output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
    myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')

    ## 2nd calling method: enhance every wav file under 'path_to_input_wavs_realtime/' in real-time,
    ## writing each result to 'path_to_output_wavs_realtime' as it is produced
    myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')

    ## 3rd calling method: enhance the wav files listed in an .scp manifest in real-time,
    ## writing results to 'path_to_output_wavs_realtime_scp/'
    myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
28 changes: 28 additions & 0 deletions clearvoice/demo_with_more_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,31 @@
# - online_write (bool): Set to True to enable saving the enhanced output during processing
# - output_path (str): Path to the directory to save the enhanced output files
myClearVoice(input_path='samples/scp/audio_samples.scp', online_write=True, output_path='samples/path_to_output_wavs_scp')

## ---------------- Demo Three: Real-Time Processing -----------------------
if False:  # Demonstrates real-time speech enhancement with the FRCRN_SE_16K model
    # Set up ClearVoice for the speech enhancement task using the FRCRN_SE_16K model
    myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])

    # 1st calling method:
    # Enhance a single noisy waveform in real-time and receive the result back
    # - input_path (str): path to the noisy input audio file (input_realtime.wav)
    # - output_wav (dict or ndarray): the enhanced output waveform
    output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
    # Persist the enhanced waveform to disk
    # - output_path (str): destination for the enhanced audio file (output_FRCRN_SE_16K_realtime.wav)
    myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')

    # 2nd calling method:
    # Enhance a whole directory of noisy files in real-time, saving each result as it is produced
    # - input_path (str): directory holding the noisy input audio files
    # - online_write (bool): True to write enhanced audio straight to files during processing
    # - output_path (str): directory that receives the enhanced output files
    myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')

    # 3rd calling method:
    # Drive real-time processing from an .scp manifest of input audio paths
    # - input_path (str): .scp file listing multiple audio file paths
    # - online_write (bool): True to write enhanced audio straight to files during processing
    # - output_path (str): directory that receives the enhanced output files
    myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
61 changes: 58 additions & 3 deletions clearvoice/models/frcrn_se/frcrn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,35 @@ def __init__(self, args):
win_type=args.win_type
)

def forward(self, x):
def forward(self, x, real_time=False):
    """Run speech enhancement on an input waveform.

    Args:
        x (torch.Tensor): Input tensor representing audio signals.
        real_time (bool): When True, route the input through the
            real-time processing path instead of the offline model call.
            Defaults to False.

    Returns:
        torch.Tensor: Processed output tensor after applying the model.
    """
    if real_time:
        return self.real_time_process(x)
    # Offline path: the wrapped model returns nested outputs; the
    # estimated waveform lives at output[1][0].
    return self.model(x)[1][0]

def real_time_process(self, x):
    """Delegate real-time enhancement to the wrapped DCCRN model.

    Args:
        x (torch.Tensor): Input tensor representing audio signals.

    Returns:
        torch.Tensor: Real-time enhanced output from the underlying model.
    """
    return self.model.real_time_process(x)


class DCCRN(nn.Module):
Expand Down Expand Up @@ -249,3 +266,41 @@ def get_params(self, weight_decay=0.0):
}]
return params

def real_time_process(self, inputs):
    """Run the DCCRN enhancement pipeline on an audio chunk for real-time use.

    Args:
        inputs (torch.Tensor): Waveform tensor for the current audio chunk.

    Returns:
        list: [estimated_spectrum, estimated_waveform, estimated_mask].
    """
    # STFT produces stacked real/imag bins: [B, D*2, T]
    spec = self.stft(inputs)
    spec = spec.unsqueeze(1)  # [B, 1, D*2, T]

    # Separate the real and imaginary halves into channels: [B, 2, D, T]
    real_part = spec[:, :, :self.feat_dim, :]
    imag_part = spec[:, :, self.feat_dim:, :]
    spec = torch.cat((real_part, imag_part), 1)

    # Move the real/imag pair into the trailing axis: [B, 1, D, T, 2]
    spec = spec.unsqueeze(4).transpose(1, 4)

    # Two-stage mask estimation; the second stage refines the first,
    # and the masks are summed before being applied.
    stage1 = self.unet(spec)
    mask1 = torch.tanh(stage1)
    stage2 = self.unet2(stage1)
    mask2 = torch.tanh(stage2) + mask1

    # Apply the combined mask and reconstruct the waveform.
    est_spec, est_wav, est_mask = self.apply_mask(spec, mask2)
    return [est_spec, est_wav, est_mask]

0 comments on commit 4f874e2

Please sign in to comment.