diff --git a/README.md b/README.md
index afb05573..690e93e8 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ pre-commit install # for clean and tidy code
 ```
 
 ## 🔥 News
-* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347).
+* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347) and [#352](https://github.com/wenet-e2e/wespeaker/pull/352).
 * 2024.08.18: Support using ssl pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344).
 * 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320).
 * 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).
diff --git a/examples/voxceleb/v2/conf/redimnet.yaml b/examples/voxceleb/v2/conf/redimnet.yaml
new file mode 100644
index 00000000..6e8fb295
--- /dev/null
+++ b/examples/voxceleb/v2/conf/redimnet.yaml
@@ -0,0 +1,86 @@
+exp_dir: exp/RedimnetB2-emb192-fbank72-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch120
+gpus: "[0,1]"
+num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
+
+seed: 42
+num_epochs: 120
+save_epoch_interval: 5 # save model every 5 epochs
+log_batch_interval: 100 # log every 100 batches
+
+dataloader_args:
+  batch_size: 256
+  num_workers: 4
+  pin_memory: false
+  prefetch_factor: 4
+  drop_last: true
+
+dataset_args:
+  # the number of samples traversed within one epoch; if set to 0, the utterance number of the dataset is used as sample_num_per_epoch
+  sample_num_per_epoch: 0
+  shuffle: True
+  shuffle_args:
+    shuffle_size: 2500
+  filter: True
+  filter_args:
+    min_num_frames: 100
+    max_num_frames: 800
+  resample_rate: 16000
+  speed_perturb: True
+  num_frms: 200
+  aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  fbank_args:
+    num_mel_bins: 72
+    frame_shift: 10
+    frame_length: 25
+    dither: 1.0
+  spec_aug: False
+  spec_aug_args:
+    num_t_mask: 1
+    num_f_mask: 1
+    max_t: 10
+    max_f: 8
+    prob: 0.6
+
+model: ReDimNetB2
+model_init: null
+model_args:
+  feat_dim: 72
+  embed_dim: 192
+  pooling_func: "ASTP" # TSTP, ASTP, MQMHASTP
+  two_emb_layer: False
+
+
+projection_args:
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter
+  scale: 32.0
+  easy_margin: False
+
+
+margin_scheduler: MarginScheduler
+margin_update:
+  initial_margin: 0.0
+  final_margin: 0.2
+  increase_start_epoch: 20
+  fix_start_epoch: 40
+  update_margin: True
+  increase_type: "exp" # exp, linear
+
+loss: CrossEntropyLoss
+loss_args: {}
+
+optimizer: SGD
+optimizer_args:
+  momentum: 0.9
+  nesterov: True
+  weight_decay: 2.0e-05
+
+scheduler: ExponentialDecrease
+scheduler_args:
+  initial_lr: 0.1
+  final_lr: 0.00005
+  warm_up_epoch: 6
+  warm_from_zero: True
+
diff --git a/examples/voxconverse/README.md b/examples/voxconverse/README.md
index 3fb5772b..85af1c2d 100644
--- a/examples/voxconverse/README.md
+++ b/examples/voxconverse/README.md
@@ -1,3 +1,7 @@
 This is a **WeSpeaker** speaker diarization recipe on the Voxconverse 2020 dataset. It focused on a ``in the wild`` scenario, which was collected from YouTube videos with a semi-automatic pipeline and released for the diarization track in VoxSRC 2020 Challenge. See https://www.robots.ox.ac.uk/~vgg/data/voxconverse/ for more detailed information.
 
 Two recipes are provided, including **v1** and **v2**. Their only difference is that in **v2**, we split the Fbank extraction, embedding extraction and clustering modules to different stages. We recommend newcomers to follow the **v2** recipe and run it stage by stage.
+
+🔥 UPDATE 2024.08.20:
+* silero-vad v5.1 is used in place of v3.1
+* umap dimensionality reduction + hdbscan clustering is also supported in v2
diff --git a/examples/voxconverse/v1/README.md b/examples/voxconverse/v1/README.md
index 13ec2aef..0c51b555 100644
--- a/examples/voxconverse/v1/README.md
+++ b/examples/voxconverse/v1/README.md
@@ -1,11 +1,13 @@
 ## Overview
 
 * We suggest to run this recipe on a gpu-available machine, with onnxruntime-gpu supported.
-* Dataset: voxconverse_dev that consists of 216 utterances -* Speaker model: ResNet34 model pretrained by wespeaker +* Dataset: Voxconverse2020 (dev: 216 utts) +* Speaker model: ResNet34 model pretrained by WeSpeaker * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) -* Speaker activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad) +* Speaker activity detection model: + * oracle SAD (from ground truth annotation) + * system SAD (VAD model pretrained by [silero-vad](https://github.com/snakers4/silero-vad), v3.1 is deprecated now) * Clustering method: spectral clustering * Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%) @@ -15,8 +17,8 @@ | system | MISS | FA | SC | DER | |:---|:---:|:---:|:---:|:---:| - | This repo (with oracle SAD) | 2.3 | 0.0 | 1.9 | 4.2 | - | This repo (with system SAD) | 3.7 | 0.8 | 2.0 | 6.5 | + | Ours (oracle SAD + spectral clustering) | 2.3 | 0.0 | 1.9 | 4.2 | + | Ours (silero-vad v3.1 + spectral clustering) | 3.7 | 0.8 | 2.0 | 6.5 | | DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 | | DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 | | (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 | diff --git a/examples/voxconverse/v1/run.sh b/examples/voxconverse/v1/run.sh index b852ec80..f60bd3ad 100755 --- a/examples/voxconverse/v1/run.sh +++ b/examples/voxconverse/v1/run.sh @@ -29,8 +29,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools # [2] Download voice activity detection model pretrained by Silero Team - wget -c https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.zip -O external_tools/silero-vad-v3.1.zip - unzip -o external_tools/silero-vad-v3.1.zip -d external_tools + #wget -c https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.zip -O external_tools/silero-vad-v3.1.zip + #unzip -o external_tools/silero-vad-v3.1.zip -d external_tools # [3] Download ResNet34 speaker model pretrained by WeSpeaker Team mkdir -p pretrained_models @@ -79,7 +79,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [[ "x${sad_type}" == "xsystem" ]]; then # System SAD: applying 'silero' VAD python3 wespeaker/diar/make_system_sad.py \ - --repo-path external_tools/silero-vad-3.1 \ --scp data/dev/wav.scp \ --min-duration $min_duration > data/dev/system_sad fi diff --git a/examples/voxconverse/v2/README.md b/examples/voxconverse/v2/README.md index 02f41fa1..7a1d339a 100644 --- a/examples/voxconverse/v2/README.md +++ b/examples/voxconverse/v2/README.md @@ -1,12 +1,16 @@ ## Overview * We suggest to run this recipe on a gpu-available machine, with onnxruntime-gpu supported. 
-* Dataset: voxconverse_dev that consists of 216 utterances -* Speaker model: ResNet34 model pretrained by wespeaker +* Dataset: Voxconverse2020 (dev: 216 utts, test: 232 utts) +* Speaker model: ResNet34 model pretrained by WeSpeaker * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) -* Speaker activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad) -* Clustering method: spectral clustering +* Speaker activity detection model: + * oracle SAD (from ground truth annotation) + * system SAD (VAD model pretrained by [silero-vad](https://github.com/snakers4/silero-vad), v3.1 => v5.1) +* Clustering method: + * spectral clustering + * umap dimensionality reduction + hdbscan clustering * Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%) ## Results @@ -15,8 +19,11 @@ | system | MISS | FA | SC | DER | |:---|:---:|:---:|:---:|:---:| - | This repo (with oracle SAD) | 2.3 | 0.0 | 2.1 | 4.4 | - | This repo (with system SAD) | 3.7 | 0.8 | 2.2 | 6.8 | + | Ours (oracle SAD + spectral clustering) | 2.3 | 0.0 | 2.1 | 4.4 | + | Ours (oracle SAD + umap clustering) | 2.3 | 0.0 | 1.3 | 3.6 | + | Ours (silero-vad v3.1 + spectral clustering) | 3.7 | 0.8 | 2.2 | 6.7 | + | Ours (silero-vad v5.1 + spectral clustering) | 3.4 | 0.6 | 2.3 | 6.3 | + | Ours (silero-vad v5.1 + umap clustering) | 3.4 | 0.6 | 1.4 | 5.4 | | DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 | | DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 | | (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 | @@ -27,7 +34,11 @@ | system | MISS | FA | SC | DER | |:---|:---:|:---:|:---:|:---:| - | This repo (with system SAD) | 4.0 | 2.4 | 3.4 | 9.8 | + | Ours (oracle SAD + spectral clustering) | 1.6 | 0.0 | 3.3 | 4.9 | + | Ours (oracle SAD + umap clustering) | 1.6 | 0.0 | 1.9 | 3.5 | + | Ours (silero-vad v3.1 + spectral clustering) | 4.0 | 2.4 | 3.4 | 9.8 | + | Ours (silero-vad v5.1 + spectral clustering) | 3.8 | 1.7 | 3.3 | 8.8 | + | Ours (silero-vad v5.1 + umap clustering) | 3.8 | 1.7 | 1.8 | 7.3 | [^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf diff --git a/examples/voxconverse/v2/run.sh b/examples/voxconverse/v2/run.sh index 8e786297..6c83171c 100755 --- a/examples/voxconverse/v2/run.sh +++ b/examples/voxconverse/v2/run.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright (c) 2022-2023 Xu Xiang # 2022 Zhengyang Chen (chenzhengyang117@gmail.com) +# 2024 Hongji Wang (jijijiang77@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,8 +19,9 @@
 
 stage=-1
 stop_stage=-1
-sad_type="oracle"
-partition="dev"
+sad_type="oracle" # oracle/system
+partition="dev" # dev/test
+cluster_type="spectral" # spectral/umap
 
 # do cmn on the sub-segment or on the vad segment
 subseg_cmn=true
@@ -36,11 +38,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     wget -c https://github.com/usnistgov/SCTK/archive/refs/tags/v2.4.12.zip -O external_tools/SCTK-v2.4.12.zip
     unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools
 
-    # [2] Download voice activity detection model pretrained by Silero Team
-    wget -c https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.zip -O external_tools/silero-vad-v3.1.zip
-    unzip -o external_tools/silero-vad-v3.1.zip -d external_tools
-
-    # [3] Download ResNet34 speaker model pretrained by WeSpeaker Team
+    # [2] Download ResNet34 speaker model pretrained by WeSpeaker Team
     mkdir -p pretrained_models
 
     wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O pretrained_models/voxceleb_resnet34_LM.onnx
@@ -101,7 +99,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     if [[ "x${sad_type}" == "xsystem" ]]; then
        # System SAD: applying 'silero' VAD
        python3 wespeaker/diar/make_system_sad.py \
-              --repo-path external_tools/silero-vad-3.1 \
               --scp data/${partition}/wav.scp \
               --min-duration $min_duration > data/${partition}/system_sad
     fi
@@ -144,24 +141,24 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
 
 
-# Applying spectral clustering algorithm
+# Applying spectral or umap+hdbscan clustering algorithm
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-    [ -f "exp/spectral_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/spectral_cluster/${partition}_${sad_type}_sad_labels
+    [ -f "exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_labels
 
-    echo "Doing spectral clustering and store the result in exp/spectral_cluster/${partition}_${sad_type}_sad_labels"
+    echo "Doing ${cluster_type} clustering and storing the result in exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_labels"
     echo "..."
- python3 wespeaker/diar/spectral_clusterer.py \ + python3 wespeaker/diar/${cluster_type}_clusterer.py \ --scp exp/${partition}_${sad_type}_sad_embedding/emb.scp \ - --output exp/spectral_cluster/${partition}_${sad_type}_sad_labels + --output exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_labels fi # Convert labels to RTTMs if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then python3 wespeaker/diar/make_rttm.py \ - --labels exp/spectral_cluster/${partition}_${sad_type}_sad_labels \ - --channel 1 > exp/spectral_cluster/${partition}_${sad_type}_sad_rttm + --labels exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_labels \ + --channel 1 > exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_rttm fi @@ -173,18 +170,18 @@ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \ -c 0.25 \ -r <(cat ${ref_dir}/${partition}/*.rttm) \ - -s exp/spectral_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/spectral_cluster/${partition}_${sad_type}_sad_res + -s exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_res if [ ${get_each_file_res} -eq 1 ];then - single_file_res_dir=exp/spectral_cluster/${partition}_${sad_type}_single_file_res + single_file_res_dir=exp/${cluster_type}_cluster/${partition}_${sad_type}_single_file_res mkdir -p $single_file_res_dir echo -e "\nGet the DER results for each file and the results will be stored underd ${single_file_res_dir}\n..." - awk '{print $2}' exp/spectral_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do + awk '{print $2}' exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \ -c 0.25 \ -r <(cat ${ref_dir}/${partition}/${file_name}.rttm) \ - -s <(grep "${file_name}" exp/spectral_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res + -s <(grep "${file_name}" exp/${cluster_type}_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res done echo "Done!" fi diff --git a/examples/voxconverse/v3/README.md b/examples/voxconverse/v3/README.md deleted file mode 100644 index 5b333714..00000000 --- a/examples/voxconverse/v3/README.md +++ /dev/null @@ -1,34 +0,0 @@ -## Overview - -* We suggest to run this recipe on a gpu-available machine, with onnxruntime-gpu supported. 
-* Dataset: voxconverse_dev that consists of 216 utterances -* Speaker model: ResNet34 model pretrained by wespeaker - * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) - * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx) -* Speaker activity detection model: oracle SAD (from ground truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad) -* Clustering method: umap dimensionality reduction + hdbscan clustering -* Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%) - -## Results - -* Dev set - - | system | MISS | FA | SC | DER | - |:---|:---:|:---:|:---:|:---:| - | This repo (with oracle SAD) | 2.3 | 0.0 | 1.3 | 3.6 | - | This repo (with system SAD) | 3.4 | 0.6 | 1.4 | 5.4 | - | DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 | - | DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 | - | (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 | - | (AVSE ASD only) [^1] | 2.0 | 5.9 | 4.6 | 12.4 | - | (proposed) [^1] | 2.4 | 2.3 | 3.0 | 7.7 | - -* Test set - - | system | MISS | FA | SC | DER | - |:---|:---:|:---:|:---:|:---:| - | This repo (with oracle SAD) | 1.6 | 0.0 | 1.9 | 3.5 | - | This repo (with system SAD) | 3.8 | 1.7 | 1.8 | 7.4 | - - -[^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf diff --git a/examples/voxconverse/v3/local b/examples/voxconverse/v3/local deleted file mode 120000 index 8b1d5f97..00000000 --- a/examples/voxconverse/v3/local +++ /dev/null @@ -1 +0,0 @@ -../v2/local \ No newline at end of file diff --git a/examples/voxconverse/v3/path.sh b/examples/voxconverse/v3/path.sh deleted file mode 120000 index b6a713c8..00000000 --- a/examples/voxconverse/v3/path.sh +++ /dev/null @@ -1 +0,0 @@ -../v2/path.sh \ No newline at end of file diff --git a/examples/voxconverse/v3/run.sh b/examples/voxconverse/v3/run.sh deleted file mode 100755 index f53cfab5..00000000 --- a/examples/voxconverse/v3/run.sh +++ /dev/null @@ -1,186 +0,0 @@ -#!/bin/bash -# Copyright (c) 2022-2023 Xu Xiang -# 2022 Zhengyang Chen (chenzhengyang117@gmail.com) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -. ./path.sh || exit 1 - -stage=-1 -stop_stage=-1 -sad_type="oracle" -partition="dev" - -# do cmn on the sub-segment or on the vad segment -subseg_cmn=true -# whether print the evaluation result for each file -get_each_file_res=1 - -. 
tools/parse_options.sh - -# Prerequisite -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - mkdir -p external_tools - - # [1] Download evaluation toolkit - wget -c https://github.com/usnistgov/SCTK/archive/refs/tags/v2.4.12.zip -O external_tools/SCTK-v2.4.12.zip - unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools - - # [3] Download ResNet34 speaker model pretrained by WeSpeaker Team - mkdir -p pretrained_models - - wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O pretrained_models/voxceleb_resnet34_LM.onnx -fi - - -# Download VoxConverse dev/test audios and the corresponding annotations -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - mkdir -p data - - # Download annotations for dev and test sets (version 0.0.3) - wget -c https://github.com/joonson/voxconverse/archive/refs/heads/master.zip -O data/voxconverse_master.zip - unzip -o data/voxconverse_master.zip -d data - - # Download annotations from VoxSRC-23 validation toolkit (looks like version 0.0.2) - # cd data && git clone https://github.com/JaesungHuh/VoxSRC2023.git --recursive && cd - - - # Download dev audios - mkdir -p data/dev - - #wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip - # The above url may not be reachable, you can try the link below. - # This url is from https://github.com/joonson/voxconverse/blob/master/README.md - wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip - unzip -o data/voxconverse_dev_wav.zip -d data/dev - - # Create wav.scp for dev audios - ls `pwd`/data/dev/audio/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/dev/wav.scp - - # Test audios - mkdir -p data/test - - #wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip - # The above url may not be reachable, you can try the link below. - # This url is from https://github.com/joonson/voxconverse/blob/master/README.md - wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip - unzip -o data/voxconverse_test_wav.zip -d data/test - - # Create wav.scp for test audios - ls `pwd`/data/test/voxconverse_test_wav/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/test/wav.scp -fi - - -# Voice activity detection -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # Set VAD min duration - min_duration=0.255 - - if [[ "x${sad_type}" == "xoracle" ]]; then - # Oracle SAD: handling overlapping or too short regions in ground truth RTTM - while read -r utt wav_path; do - python3 wespeaker/diar/make_oracle_sad.py \ - --rttm data/voxconverse-master/${partition}/${utt}.rttm \ - --min-duration $min_duration - done < data/${partition}/wav.scp > data/${partition}/oracle_sad - fi - - if [[ "x${sad_type}" == "xsystem" ]]; then - # System SAD: applying 'silero' VAD - python3 wespeaker/diar/make_system_sad.py \ - --scp data/${partition}/wav.scp \ - --min-duration $min_duration > data/${partition}/system_sad - fi -fi - - -# Extract fbank features -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - - [ -d "exp/${sad_type}_sad_fbank" ] && rm -r exp/${sad_type}_sad_fbank - - echo "Make Fbank features and store it under exp/${sad_type}_sad_fbank" - echo "..." 
- bash local/make_fbank.sh \ - --scp data/${partition}/wav.scp \ - --segments data/${partition}/${sad_type}_sad \ - --store_dir exp/${partition}_${sad_type}_sad_fbank \ - --subseg_cmn ${subseg_cmn} \ - --nj 24 -fi - -# Extract embeddings -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - - [ -d "exp/${sad_type}_sad_embedding" ] && rm -r exp/${sad_type}_sad_embedding - - echo "Extract embeddings and store it under exp/${sad_type}_sad_embedding" - echo "..." - bash local/extract_emb.sh \ - --scp exp/${partition}_${sad_type}_sad_fbank/fbank.scp \ - --pretrained_model pretrained_models/voxceleb_resnet34_LM.onnx \ - --device cuda \ - --store_dir exp/${partition}_${sad_type}_sad_embedding \ - --batch_size 96 \ - --frame_shift 10 \ - --window_secs 1.5 \ - --period_secs 0.75 \ - --subseg_cmn ${subseg_cmn} \ - --nj 1 -fi - - -# Applying umap clustering algorithm -if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then - - [ -f "exp/umap_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/umap_cluster/${partition}_${sad_type}_sad_labels - - echo "Doing umap clustering and store the result in exp/umap_cluster/${partition}_${sad_type}_sad_labels" - echo "..." - python3 wespeaker/diar/umap_clusterer.py \ - --scp exp/${partition}_${sad_type}_sad_embedding/emb.scp \ - --output exp/umap_cluster/${partition}_${sad_type}_sad_labels -fi - - -# Convert labels to RTTMs -if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then - python3 wespeaker/diar/make_rttm.py \ - --labels exp/umap_cluster/${partition}_${sad_type}_sad_labels \ - --channel 1 > exp/umap_cluster/${partition}_${sad_type}_sad_rttm -fi - - -# Evaluate the result -if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then - ref_dir=data/voxconverse-master/ - #ref_dir=data/VoxSRC2023/voxconverse/ - echo -e "Get the DER results\n..." - perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \ - -c 0.25 \ - -r <(cat ${ref_dir}/${partition}/*.rttm) \ - -s exp/umap_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/umap_cluster/${partition}_${sad_type}_sad_res - - if [ ${get_each_file_res} -eq 1 ];then - single_file_res_dir=exp/umap_cluster/${partition}_${sad_type}_single_file_res - mkdir -p $single_file_res_dir - echo -e "\nGet the DER results for each file and the results will be stored underd ${single_file_res_dir}\n..." - - awk '{print $2}' exp/umap_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do - perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \ - -c 0.25 \ - -r <(cat ${ref_dir}/${partition}/${file_name}.rttm) \ - -s <(grep "${file_name}" exp/umap_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res - done - echo "Done!" 
- fi -fi diff --git a/examples/voxconverse/v3/tools b/examples/voxconverse/v3/tools deleted file mode 120000 index c92f4172..00000000 --- a/examples/voxconverse/v3/tools +++ /dev/null @@ -1 +0,0 @@ -../../../tools \ No newline at end of file diff --git a/examples/voxconverse/v3/wespeaker b/examples/voxconverse/v3/wespeaker deleted file mode 120000 index 900c560b..00000000 --- a/examples/voxconverse/v3/wespeaker +++ /dev/null @@ -1 +0,0 @@ -../../../wespeaker \ No newline at end of file diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py new file mode 100644 index 00000000..ecd7db17 --- /dev/null +++ b/wespeaker/models/redimnet.py @@ -0,0 +1,1032 @@ +# Copyright (c) 2024 https://github.com/IDRnD/ReDimNet +# 2024 Shuai Wang (wsstriving@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redimnet in pytorch. + +Reference: +Paper: "Reshape Dimensions Network for Speaker Recognition" +Repo: https://github.com/IDRnD/ReDimNet + +Cite: +@misc{yakovlev2024reshapedimensionsnetworkspeaker, + title={Reshape Dimensions Network for Speaker Recognition}, + author={Ivan Yakovlev and Rostislav Makarov and Andrei Balykin + and Pavel Malov and Anton Okhotnikov and Nikita Torgashov}, + year={2024}, + eprint={2407.18223}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + url={https://arxiv.org/abs/2407.18223}, +} +""" +import math + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import wespeaker.models.pooling_layers as pooling_layers + + +MaxPoolNd = {1: nn.MaxPool1d, 2: nn.MaxPool2d} +ConvNd = {1: nn.Conv1d, 2: nn.Conv2d} +BatchNormNd = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d} + + +class to1d(nn.Module): + def forward(self, x): + size = x.size() + bs, c, f, t = tuple(size) + return x.permute((0, 2, 1, 3)).reshape((bs, c * f, t)) + + +class NewGELUActivation(nn.Module): + def forward(self, input): + return ( + 0.5 + * input + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) + * (input + 0.044715 * torch.pow(input, 3.0)) + ) + ) + ) + + +class LayerNorm(nn.Module): + """ + LayerNorm that supports two data formats: channels_last or channels_first. + The ordering of the dimensions in the inputs. + channels_last corresponds to inputs with shape (batch_size, T, channels) + while channels_first corresponds to shape (batch_size, channels, T). 
+ """ + + def __init__(self, C, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(C)) + self.bias = nn.Parameter(torch.zeros(C)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.C = (C,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.C, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + + w = self.weight + b = self.bias + for _ in range(x.ndim - 2): + w = w.unsqueeze(-1) + b = b.unsqueeze(-1) + x = w * x + b + return x + + def extra_repr(self) -> str: + return ", ".join( + [ + f"{k}={v}" + for k, v in { + "C": self.C, + "data_format": self.data_format, + "eps": self.eps, + }.items() + ] + ) + + +class GRU(nn.Module): + def __init__(self, *args, **kwargs): + super(GRU, self).__init__() + self.gru = nn.GRU(*args, **kwargs) + + def forward(self, x): + # x : (bs,C,T) + return self.gru(x.permute((0, 2, 1)))[0].permute((0, 2, 1)) + + +class PosEncConv(nn.Module): + def __init__(self, C, ks, groups=None): + super().__init__() + assert ks % 2 == 1 + self.conv = nn.Conv1d( + C, C, ks, padding=ks // 2, groups=C if groups is None else groups + ) + self.norm = LayerNorm(C, eps=1e-6, data_format="channels_first") + + def forward(self, x): + return x + self.norm(self.conv(x)) + + +class ConvNeXtLikeBlock(nn.Module): + def __init__( + self, + C, + dim=2, + kernel_sizes=((3, 3),), + group_divisor=1, + padding="same", + ): + super().__init__() + self.dwconvs = nn.ModuleList( + modules=[ + ConvNd[dim]( + C, + C, + kernel_size=ks, + padding=padding, + groups=C // group_divisor if group_divisor is not None else 1, + ) + for ks in kernel_sizes + ] + ) + self.norm = BatchNormNd[dim](C * len(kernel_sizes)) + self.gelu = nn.GELU() + self.pwconv1 = ConvNd[dim](C * len(kernel_sizes), C, 1) + + def forward(self, x): + skip = x + x = torch.cat([dwconv(x) for dwconv in self.dwconvs], dim=1) + x = self.gelu(self.norm(x)) + x = self.pwconv1(x) + x = skip + x + return x + + +class ConvBlock2d(nn.Module): + def __init__(self, c, f, block_type="convnext_like", group_divisor=1): + super().__init__() + if block_type == "convnext_like": + self.conv_block = ConvNeXtLikeBlock( + c, + dim=2, + kernel_sizes=[(3, 3)], + group_divisor=group_divisor, + padding="same", + ) + elif block_type == "basic_resnet": + self.conv_block = ResBasicBlock( + c, + c, + f, + stride=1, + se_channels=min(64, max(c, 32)), + group_divisor=group_divisor, + use_fwSE=False, + ) + elif block_type == "basic_resnet_fwse": + self.conv_block = ResBasicBlock( + c, + c, + f, + stride=1, + se_channels=min(64, max(c, 32)), + group_divisor=group_divisor, + use_fwSE=True, + ) + else: + raise NotImplementedError() + + def forward(self, x): + return self.conv_block(x) + + +class MultiHeadAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got " + f"`embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len, bsz): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + attn_weights = F.softmax(attn_weights, dim=-1) + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) + # rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + n_state, + n_mlp, + n_head, + channel_last=False, + act_do=0.0, + att_do=0.0, + hid_do=0.0, + ln_eps=1e-6, + ): + + hidden_size = n_state + num_attention_heads = n_head + intermediate_size = n_mlp + activation_dropout = act_do + attention_dropout = att_do + hidden_dropout = hid_do + layer_norm_eps = ln_eps + + super().__init__() + self.channel_last = channel_last + self.attention = MultiHeadAttention( + embed_dim=hidden_size, + num_heads=num_attention_heads, + dropout=attention_dropout, + ) + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.feed_forward = FeedForward( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_dropout=activation_dropout, + hidden_dropout=hidden_dropout, + ) + self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states): + if not self.channel_last: + hidden_states = hidden_states.permute(0, 2, 1) + attn_residual = hidden_states + hidden_states = self.attention(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = hidden_states + if not self.channel_last: + outputs = outputs.permute(0, 2, 1) + return outputs + + +class FeedForward(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + activation_dropout=0.0, + hidden_dropout=0.0, + ): + super().__init__() + self.intermediate_dropout = nn.Dropout(activation_dropout) + self.intermediate_dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = NewGELUActivation() + self.output_dense = 
nn.Linear(intermediate_size, hidden_size) + self.output_dropout = nn.Dropout(hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class BasicBlock(nn.Module): + """ + Key difference with the BasicBlock in resnet.py: + 1. If use group convolution, conv1 have same number of input/output channels + 2. No stride to downsample + """ + + def __init__( + self, + in_planes, + planes, + stride=1, + group_divisor=4, + ): + super().__init__() + self.conv1 = nn.Conv2d( + in_planes, + in_planes if group_divisor is not None else planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=in_planes // group_divisor if group_divisor is not None else 1, + ) + + # If using group convolution, add point-wise conv to reshape + if group_divisor is not None: + self.conv1pw = nn.Conv2d(in_planes, planes, 1) + else: + self.conv1pw = nn.Identity() + + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + padding=1, + bias=False, + groups=planes // group_divisor if group_divisor is not None else 1, + ) + + # If using group convolution, add point-wise conv to reshape + if group_divisor is not None: + self.conv2pw = nn.Conv2d(planes, planes, 1) + else: + self.conv2pw = nn.Identity() + + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if planes != in_planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes), + ) + else: + self.shortcut = nn.Identity() + + def forward(self, x): + residual = x + + out = self.conv1pw(self.conv1(x)) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2pw(self.conv2(out)) + out = self.bn2(out) + + out += self.shortcut(residual) + out = self.relu(out) + return out + + +class fwSEBlock(nn.Module): + """ + Squeeze-and-Excitation block + link: https://arxiv.org/pdf/1709.01507.pdf + PyTorch implementation + """ + + def __init__(self, num_freq, num_feats=64): + super(fwSEBlock, self).__init__() + self.squeeze = nn.Linear(num_freq, num_feats) + self.exitation = nn.Linear(num_feats, num_freq) + + self.activation = nn.ReLU() # Assuming ReLU, modify as needed + + def forward(self, inputs): + # [bs, C, F, T] + x = torch.mean(inputs, dim=[1, 3]) + x = self.squeeze(x) + x = self.activation(x) + x = self.exitation(x) + x = torch.sigmoid(x) + # Reshape and apply excitation + x = x[:, None, :, None] + x = inputs * x + return x + + +class ResBasicBlock(nn.Module): + def __init__( + self, + in_planes, + planes, + num_freq, + stride=1, + se_channels=64, + group_divisor=4, + use_fwSE=False, + ): + super().__init__() + self.conv1 = nn.Conv2d( + in_planes, + in_planes if group_divisor is not None else planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=in_planes // group_divisor if group_divisor is not None else 1, + ) + if group_divisor is not None: + self.conv1pw = nn.Conv2d(in_planes, planes, 1) + else: + self.conv1pw = nn.Identity() + + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + padding=1, + bias=False, + groups=planes // group_divisor if group_divisor is not None else 1, + ) + + if group_divisor is not None: + self.conv2pw = nn.Conv2d(planes, 
planes, 1) + else: + self.conv2pw = nn.Identity() + + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if use_fwSE: + self.se = fwSEBlock(num_freq, se_channels) + else: + self.se = nn.Identity() + + if planes != in_planes: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes), + ) + else: + self.downsample = nn.Identity() + + def forward(self, x): + residual = x + + out = self.conv1pw(self.conv1(x)) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2pw(self.conv2(out)) + out = self.bn2(out) + out = self.se(out) + + out += self.downsample(residual) + out = self.relu(out) + return out + + +class TimeContextBlock1d(nn.Module): + """ """ + + def __init__( + self, + C, + hC, + pos_ker_sz=59, + block_type="att", + ): + super().__init__() + assert pos_ker_sz + + self.red_dim_conv = nn.Sequential( + nn.Conv1d(C, hC, 1), LayerNorm(hC, eps=1e-6, data_format="channels_first") + ) + + if block_type == "fc": + self.tcm = nn.Sequential( + nn.Conv1d(hC, hC * 2, 1), + LayerNorm(hC * 2, eps=1e-6, data_format="channels_first"), + nn.GELU(), + nn.Conv1d(hC * 2, hC, 1), + ) + elif block_type == "gru": + # Just GRU + self.tcm = nn.Sequential( + GRU( + input_size=hC, + hidden_size=hC, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=True, + ), + nn.Conv1d(2 * hC, hC, 1), + ) + elif block_type == "att": + # Basic Transformer self-attention encoder block + self.tcm = nn.Sequential( + PosEncConv(hC, ks=pos_ker_sz, groups=hC), + TransformerEncoderLayer(n_state=hC, n_mlp=hC * 2, n_head=4), + ) + elif block_type == "conv+att": + # Basic Transformer self-attention encoder block + self.tcm = nn.Sequential( + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[7], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[19], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[31], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[59], group_divisor=1, padding="same" + ), + TransformerEncoderLayer(n_state=hC, n_mlp=hC, n_head=4), + ) + else: + raise NotImplementedError() + + self.exp_dim_conv = nn.Conv1d(hC, C, 1) + + def forward(self, x): + skip = x + x = self.red_dim_conv(x) + x = self.tcm(x) + x = self.exp_dim_conv(x) + return skip + x + + +class ReDimNetBone(nn.Module): + def __init__( + self, + F=72, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=( + # stride, num_blocks, conv_exp, kernel_size, att_block_red + (1, 2, 1, [(3, 3)], None), # 16 + (2, 3, 1, [(3, 3)], None), # 32 + # 64, (72*12 // 8) = 108 - channels in attention block + (3, 4, 1, [(3, 3)], 8), + (2, 5, 1, [(3, 3)], 8), # 128 + (1, 5, 1, [(7, 1)], 8), # 128 # TDNN - time context + (2, 3, 1, [(3, 3)], 8), # 256 + ), + group_divisor=1, + out_channels=512, + ): + super().__init__() + self.F = F + self.C = C + + self.block_1d_type = block_1d_type + self.block_2d_type = block_2d_type + + self.stages_setup = stages_setup + self.build(stages_setup, group_divisor, out_channels) + + def build(self, stages_setup, group_divisor, out_channels): + self.num_stages = len(stages_setup) + + cur_c = self.C + cur_f = self.F + # Weighting the inputs + # TODO: ask authors about the impact of this pre-weighting + self.inputs_weights = torch.nn.ParameterList( + [nn.Parameter(torch.ones(1, 1, 1, 1), requires_grad=False)] + + [ + nn.Parameter( + torch.zeros(1, num_inputs + 1, self.C * self.F, 1), + 
requires_grad=True, + ) + for num_inputs in range(1, len(stages_setup) + 1) + ] + ) + + self.stem = nn.Sequential( + nn.Conv2d(1, int(cur_c), kernel_size=3, stride=1, padding="same"), + LayerNorm(int(cur_c), eps=1e-6, data_format="channels_first"), + ) + + Block1d = functools.partial(TimeContextBlock1d, block_type=self.block_1d_type) + Block2d = functools.partial(ConvBlock2d, block_type=self.block_2d_type) + + self.stages_cfs = [] + for stage_ind, ( + stride, + num_blocks, + conv_exp, + kernel_sizes, # TODO: Why the kernel_sizes are not used? + att_block_red, + ) in enumerate(stages_setup): + assert stride in [1, 2, 3] + # Pool frequencies & expand channels if needed + layers = [ + nn.Conv2d( + int(cur_c), + int(stride * cur_c * conv_exp), + kernel_size=(stride, 1), + stride=(stride, 1), + padding=0, + groups=1, + ), + ] + + self.stages_cfs.append((cur_c, cur_f)) + + cur_c = stride * cur_c + assert cur_f % stride == 0 + cur_f = cur_f // stride + + for _ in range(num_blocks): + # ConvBlock2d(f, c, block_type="convnext_like", group_divisor=1) + layers.append( + Block2d( + c=int(cur_c * conv_exp), f=cur_f, group_divisor=group_divisor + ) + ) + + if conv_exp != 1: + # Squeeze back channels to align with ReDimNet c+f reshaping: + _group_divisor = group_divisor + # if c // group_divisor == 0: + # _group_divisor = c + layers.append( + nn.Sequential( + nn.Conv2d( + int(cur_c * conv_exp), + cur_c, + kernel_size=(3, 3), + stride=1, + padding="same", + groups=( + cur_c // _group_divisor + if _group_divisor is not None + else 1 + ), + ), + nn.BatchNorm2d( + cur_c, + eps=1e-6, + ), + nn.GELU(), + nn.Conv2d(cur_c, cur_c, 1), + ) + ) + + layers.append(to1d()) + + # reduce block? + if att_block_red is not None: + layers.append( + Block1d(self.C * self.F, hC=(self.C * self.F) // att_block_red) + ) + + setattr(self, f"stage{stage_ind}", nn.Sequential(*layers)) + + if out_channels is not None: + self.mfa = nn.Sequential( + nn.Conv1d(self.F * self.C, out_channels, kernel_size=1, padding="same"), + nn.BatchNorm1d(out_channels, affine=True), + ) + else: + self.mfa = nn.Identity() + + def to1d(self, x): + size = x.size() + bs, c, f, t = tuple(size) + return x.permute((0, 2, 1, 3)).reshape((bs, c * f, t)) + + def to2d(self, x, c, f): + size = x.size() + bs, cf, t = tuple(size) + return x.reshape((bs, f, c, t)).permute((0, 2, 1, 3)) + + def weigth1d(self, outs_1d, i): + xs = torch.cat([t.unsqueeze(1) for t in outs_1d], dim=1) + w = F.softmax(self.inputs_weights[i], dim=1) + x = (w * xs).sum(dim=1) + return x + + def run_stage(self, prev_outs_1d, stage_ind): + stage = getattr(self, f"stage{stage_ind}") + c, f = self.stages_cfs[stage_ind] + + x = self.weigth1d(prev_outs_1d, stage_ind) + x = self.to2d(x, c, f) + x = stage(x) + return x + + def forward(self, inp): + x = self.stem(inp) + outputs_1d = [self.to1d(x)] + for stage_ind in range(self.num_stages): + outputs_1d.append(self.run_stage(outputs_1d, stage_ind)) + x = self.weigth1d(outputs_1d, -1) + x = self.mfa(x) + return x + + +class ReDimNet(nn.Module): + def __init__( + self, + feat_dim=72, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + # Default setup: M version: + stages_setup=( + # stride, num_blocks, kernel_sizes, layer_ext, att_block_red + (1, 2, 1, [(3, 3)], 12), + (2, 2, 1, [(3, 3)], 12), + (1, 3, 1, [(3, 3)], 12), + (2, 4, 1, [(3, 3)], 8), + (1, 4, 1, [(3, 3)], 8), + (2, 4, 1, [(3, 3)], 4), + ), + group_divisor=4, + out_channels=None, + # ------------------------- + embed_dim=192, + pooling_func="ASTP", + global_context_att=True, 
+ two_emb_layer=False, + ): + + super().__init__() + self.two_emb_layer = two_emb_layer + self.backbone = ReDimNetBone( + feat_dim, + C, + block_1d_type, + block_2d_type, + stages_setup, + group_divisor, + out_channels, + ) + + if out_channels is None: + out_channels = C * feat_dim + + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=out_channels, global_context_att=global_context_att + ) + + self.pool_out_dim = self.pool.get_out_dim() + self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False) + self.seg_2 = nn.Linear(embed_dim, embed_dim) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def forward(self, x): + # x = self.spec(x).unsqueeze(1) + x = x.permute(0, 2, 1) # (B,F,T) => (B,T,F) + x = x.unsqueeze_(1) + out = self.backbone(x) + + stats = self.pool(out) + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_a, embed_b + else: + return torch.tensor(0.0), embed_a + + +def ReDimNetB0(feat_dim=60, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=10, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=[ + (1, 2, 1, [(3, 3)], 30), + (2, 3, 2, [(3, 3)], 30), + (1, 3, 3, [(3, 3)], 30), + (2, 4, 2, [(3, 3)], 10), + (1, 3, 1, [(3, 3)], 10), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB1(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=12, + block_1d_type="conv+att", + block_2d_type="convnext_like", + stages_setup=[ + (1, 2, 1, [(3, 3)], None), + (2, 3, 1, [(3, 3)], None), + (3, 4, 1, [(3, 3)], 12), + (2, 5, 1, [(3, 3)], 12), + (2, 3, 1, [(3, 3)], 8), + ], + group_divisor=8, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB2(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=16, + block_1d_type="conv+att", + block_2d_type="convnext_like", + stages_setup=[ + (1, 2, 1, [(3, 3)], 12), + (2, 2, 1, [(3, 3)], 12), + (1, 3, 1, [(3, 3)], 12), + (2, 4, 1, [(3, 3)], 8), + (1, 4, 1, [(3, 3)], 8), + (2, 4, 1, [(3, 3)], 4), + ], + group_divisor=4, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB3(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 6, 4, [(3, 3)], 32), + (2, 6, 2, [(3, 3)], 32), + (1, 8, 2, [(3, 3)], 32), + (2, 10, 2, [(3, 3)], 16), + (1, 10, 1, [(3, 3)], 16), + (2, 8, 1, [(3, 3)], 16), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB4(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 4, 2, [(3, 3)], 48), + (2, 4, 2, [(3, 3)], 48), + (1, 6, 2, [(3, 3)], 48), + (2, 6, 1, [(3, 3)], 32), + (1, 8, 1, [(3, 3)], 24), + (2, 
4, 1, [(3, 3)], 16), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB5(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 4, 2, [(3, 3)], 48), + (2, 4, 2, [(3, 3)], 48), + (1, 6, 2, [(3, 3)], 48), + (2, 6, 1, [(3, 3)], 32), + (1, 8, 1, [(3, 3)], 24), + (2, 4, 1, [(3, 3)], 16), + ], + group_divisor=16, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=[ + (1, 4, 4, [(3, 3)], 32), + (2, 6, 2, [(3, 3)], 32), + (1, 6, 2, [(3, 3)], 24), + (3, 8, 1, [(3, 3)], 24), + (1, 8, 1, [(3, 3)], 16), + (2, 8, 1, [(3, 3)], 16), + ], + group_divisor=32, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +if __name__ == "__main__": + x = torch.zeros(1, 200, 72) + model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False) + model.eval() + out = model(x) + print(out[-1].size()) + + num_params = sum(p.numel() for p in model.parameters()) + print("{} M".format(num_params / 1e6)) + + # Currently, the model sizes differ from the ones in the paper + model_classes = [ + ReDimNetB0, # 1.0M v.s. 1.0M + ReDimNetB1, # 2.1M v.s. 2.2M + ReDimNetB2, # 4.9M v.s. 4.7M + ReDimNetB3, # 3.2M v.s. 3.0M + ReDimNetB4, # 6.4M v.s. 6.3M + ReDimNetB5, # 7.65M v.s. 9.2M + ReDimNetB6, # 15.0M v.s. 15.0M + ] + + for i, model_class in enumerate(model_classes): + model = model_class() + num_params = sum(p.numel() for p in model.parameters()) + print("{} M of Model B{}".format(num_params / 1e6, i)) diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py index b08fbda8..d0c10949 100644 --- a/wespeaker/models/speaker_model.py +++ b/wespeaker/models/speaker_model.py @@ -21,6 +21,8 @@ import wespeaker.models.gemini_dfresnet as gemini import wespeaker.models.res2net as res2net import wespeaker.models.whisper_PMFA as whisper_PMFA +import wespeaker.models.redimnet as redimnet + def get_speaker_model(model_name: str): @@ -42,6 +44,8 @@ def get_speaker_model(model_name: str): return getattr(gemini, model_name) elif model_name.startswith("whisper_PMFA"): return getattr(whisper_PMFA, model_name) + elif model_name.startswith("ReDimNet"): + return getattr(redimnet, model_name) else: # model_name error !!! print(model_name + " not found !!!") exit(1)
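
Reviewer note (not part of the patch): below is a minimal sketch of how the newly registered model is expected to be used once this diff is applied. It relies only on names visible above (`get_speaker_model` in `wespeaker/models/speaker_model.py`, the `ReDimNetB2` constructor defaults, and the dummy-input check in the `__main__` block of `wespeaker/models/redimnet.py`); an installed `torch` and an importable `wespeaker` package are assumed.

```python
import torch

from wespeaker.models.speaker_model import get_speaker_model

# "ReDimNet*" model names are now dispatched to wespeaker/models/redimnet.py,
# mirroring the model/model_args entries in examples/voxceleb/v2/conf/redimnet.yaml.
model_cls = get_speaker_model("ReDimNetB2")
model = model_cls(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False)
model.eval()

# Fbank features shaped (batch, num_frames, num_mel_bins), as in the __main__ check above.
feats = torch.randn(2, 200, 72)
with torch.no_grad():
    # With two_emb_layer=False the forward returns (dummy tensor, embedding).
    _, embedding = model(feats)

print(embedding.shape)  # expected: torch.Size([2, 192])
```

For training, `examples/voxceleb/v2/conf/redimnet.yaml` wires the same names (`model: ReDimNetB2`, `feat_dim: 72`, `embed_dim: 192`) into the voxceleb v2 recipe, and is presumably consumed the same way as the other configs under `examples/voxceleb/v2/conf/`.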