From bdc280491ff1ce8c2f2ac20ff8e363d7308a26b8 Mon Sep 17 00:00:00 2001 From: Aurora1818 Date: Thu, 29 Aug 2024 16:36:51 +0800 Subject: [PATCH] Support Whisper-PMFA --- .gitignore | 1 + examples/voxceleb/README.md | 2 + examples/voxceleb/v1/Whisper-PMFA/README.md | 24 ++ .../conf/whisper_PMFA_stage0.yaml | 78 +++++ .../conf/whisper_PMFA_stage1.yaml | 77 +++++ .../v1/Whisper-PMFA/local/download_data.sh | 56 ++++ .../v1/Whisper-PMFA/local/download_whisper.sh | 13 + .../v1/Whisper-PMFA/local/extract_vox.sh | 51 ++++ .../v1/Whisper-PMFA/local/prepare_data.sh | 89 ++++++ .../voxceleb/v1/Whisper-PMFA/local/score.sh | 57 ++++ .../v1/Whisper-PMFA/local/score_norm.sh | 69 +++++ examples/voxceleb/v1/Whisper-PMFA/path.sh | 5 + examples/voxceleb/v1/Whisper-PMFA/run.sh | 129 +++++++++ examples/voxceleb/v1/Whisper-PMFA/tools | 1 + examples/voxceleb/v1/Whisper-PMFA/wespeaker | 1 + wespeaker/bin/extract.py | 2 +- wespeaker/bin/train.py | 11 +- wespeaker/frontend/__init__.py | 3 +- wespeaker/frontend/whisper_encoder.py | 273 ++++++++++++++++++ wespeaker/models/speaker_model.py | 3 + wespeaker/models/whisper_PMFA.py | 125 ++++++++ 21 files changed, 1062 insertions(+), 8 deletions(-) create mode 100644 examples/voxceleb/v1/Whisper-PMFA/README.md create mode 100644 examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage0.yaml create mode 100644 examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage1.yaml create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/download_data.sh create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/download_whisper.sh create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/extract_vox.sh create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/prepare_data.sh create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/score.sh create mode 100755 examples/voxceleb/v1/Whisper-PMFA/local/score_norm.sh create mode 100644 examples/voxceleb/v1/Whisper-PMFA/path.sh create mode 100644 examples/voxceleb/v1/Whisper-PMFA/run.sh create mode 120000 examples/voxceleb/v1/Whisper-PMFA/tools create mode 120000 examples/voxceleb/v1/Whisper-PMFA/wespeaker create mode 100644 wespeaker/frontend/whisper_encoder.py create mode 100644 wespeaker/models/whisper_PMFA.py diff --git a/.gitignore b/.gitignore index f1181632..50baa8d6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ tensorboard external_tools pretrained_models s3prl_hub +whisper_hub diff --git a/examples/voxceleb/README.md b/examples/voxceleb/README.md index ad419ccf..1a8fd909 100644 --- a/examples/voxceleb/README.md +++ b/examples/voxceleb/README.md @@ -1,6 +1,8 @@ This is a **WeSpeaker** recipe for the Voxceleb 1&2 dataset. VoxCeleb is an audio-visual dataset consisting of short clips of human speech, extracted from interview videos uploaded to YouTube. See https://www.robots.ox.ac.uk/~vgg/data/voxceleb/ for more detailed information. The following recipes are provided: +* v1: **Fully-Supervised** train on Voxceleb 1 development set and evaluate on Voxceleb1-O trials. + * v2: **Fully-Supervised** train on Voxceleb 2 development set and evaluate on three official trials. * v2_deprecated: Deprecated version of fully-supervised train on Voxceleb dataset (deprecated IO). diff --git a/examples/voxceleb/v1/Whisper-PMFA/README.md b/examples/voxceleb/v1/Whisper-PMFA/README.md new file mode 100644 index 00000000..e88e4eac --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/README.md @@ -0,0 +1,24 @@ +## Results + +* Setup: mel80, num_frms500, epoch8, ArcMargin, aug_prob0.6, speed_perturb (no spec_aug) + +* Scoring: cosine (sub mean of vox1_dev), AS-Norm + +* Metric: EER(%) + +* 🔥 UPDATE 2024.08: We support Whisper based speaker verification framework Whisper-PMFA. Related papers: + + * [Whisper-PMFA: Partial Multi-Scale Feature Aggregation for Speaker Verification using Whisper Models ](https://arxiv.org/pdf/2408.15585) + + + +| Model | AS-Norm | Params | vox1-O-clean | +| :----------------------------------- | ------- | ------ | :----------: | +| ECAPA_TDNN_GLOB_c512-ASTP-emb192 | × | 6.19M | 2.23 | +| | √ | 6.19M | 2.00 | +| ResNet34-TSTP-emb256 | × | 6.63M | 1.99 | +| | √ | 6.63M | 1.88 | +| Whisper-PMFA | × | 478.7M | 1.62 | +| | √ | 478.7M | **1.42** | +| Whisper-PMFA with LoRa (Coming soon) | √ | 10.9M | 1.62 | + diff --git a/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage0.yaml b/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage0.yaml new file mode 100644 index 00000000..47b875bb --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage0.yaml @@ -0,0 +1,78 @@ +### train configuraton + +exp_dir: exp/test +gpus: "[0,1]" +num_avg: 10 +enable_amp: False # whether enable automatic mixed precision training + +seed: 42 +num_epochs: 4 +save_epoch_interval: 1 # save model every 5 epochs +log_batch_interval: 100 # log every 100 batchs + +dataloader_args: + batch_size: 70 + num_workers: 12 + pin_memory: False + prefetch_factor: 8 + drop_last: True + +dataset_args: + shuffle: True + shuffle_args: + shuffle_size: 2500 + resample_rate: 16000 + speed_perturb: True + num_frms: 500 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + frontend: whisper_encoder + whisper_encoder_args: + frozen: True + n_mels: 80 + num_blocks: 24 + output_size: 1280 + n_head: 20 + layer_st: 16 + layer_ed: 23 + model_path: whisper_hub/large-v2.pt + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: Whisper_PMFA_large_v2 +model_init: null +model_args: + embed_dim: 192 +projection_args: + project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax + scale: 32.0 + easy_margin: False + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.2 + final_margin: 0.2 + increase_start_epoch: 0 + fix_start_epoch: 30 + update_margin: True + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: SGD +optimizer_args: + momentum: 0.9 + nesterov: True + weight_decay: 0.0001 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.0025 + final_lr: 0.00113 + warm_up_epoch: 0 + warm_from_zero: False diff --git a/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage1.yaml b/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage1.yaml new file mode 100644 index 00000000..aa936979 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/conf/whisper_PMFA_stage1.yaml @@ -0,0 +1,77 @@ +### train configuraton + +exp_dir: exp/test +gpus: "[0,1]" +num_avg: 10 +enable_amp: False # whether enable automatic mixed precision training + +seed: 42 +num_epochs: 8 +save_epoch_interval: 1 # save model every 5 epochs +log_batch_interval: 100 # log every 100 batchs + +dataloader_args: + batch_size: 15 + num_workers: 12 + pin_memory: False + prefetch_factor: 8 + drop_last: True + +dataset_args: + shuffle: True + shuffle_args: + shuffle_size: 2500 + resample_rate: 16000 + speed_perturb: True + num_frms: 500 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + frontend: whisper_encoder + whisper_encoder_args: + frozen: False + n_mels: 80 + num_blocks: 24 + output_size: 1280 + n_head: 20 + layer_st: 16 + layer_ed: 23 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: Whisper_PMFA_large_v2 +model_init: null +model_args: + embed_dim: 192 +projection_args: + project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax + scale: 32.0 + easy_margin: False + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.2 + final_margin: 0.2 + increase_start_epoch: 0 + fix_start_epoch: 30 + update_margin: True + increase_type: "exp" # exp, linear + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: SGD +optimizer_args: + momentum: 0.9 + nesterov: True + weight_decay: 0.0001 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.0025 + final_lr: 0.00073 + warm_up_epoch: 0 + warm_from_zero: False diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/download_data.sh b/examples/voxceleb/v1/Whisper-PMFA/local/download_data.sh new file mode 100755 index 00000000..61f58914 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/download_data.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +download_dir=data/download_data + +. tools/parse_options.sh || exit 1 + +[ ! -d ${download_dir} ] && mkdir -p ${download_dir} + +if [ ! -f ${download_dir}/musan.tar.gz ]; then + echo "Downloading musan.tar.gz ..." + wget --no-check-certificate https://openslr.elda.org/resources/17/musan.tar.gz -P ${download_dir} + md5=$(md5sum ${download_dir}/musan.tar.gz | awk '{print $1}') + [ $md5 != "0c472d4fc0c5141eca47ad1ffeb2a7df" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1 +fi + +if [ ! -f ${download_dir}/rirs_noises.zip ]; then + echo "Downloading rirs_noises.zip ..." + wget --no-check-certificate https://us.openslr.org/resources/28/rirs_noises.zip -P ${download_dir} + md5=$(md5sum ${download_dir}/rirs_noises.zip | awk '{print $1}') + [ $md5 != "e6f48e257286e05de56413b4779d8ffb" ] && echo "Wrong md5sum of rirs_noises.zip" && exit 1 +fi + +if [ ! -f ${download_dir}/vox1_test_wav.zip ]; then + echo "Downloading vox1_test_wav.zip ..." + wget --no-check-certificate https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip -P ${download_dir} + md5=$(md5sum ${download_dir}/vox1_test_wav.zip | awk '{print $1}') + [ $md5 != "185fdc63c3c739954633d50379a3d102" ] && echo "Wrong md5sum of vox1_test_wav.zip" && exit 1 +fi + +if [ ! -f ${download_dir}/vox1_dev_wav.zip ]; then + echo "Downloading vox1_dev_wav.zip ..." + for part in a b c d; do + wget --no-check-certificate https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_parta${part} -P ${download_dir} & + done + wait + cat ${download_dir}/vox1_dev* >${download_dir}/vox1_dev_wav.zip + md5=$(md5sum ${download_dir}/vox1_dev_wav.zip | awk '{print $1}') + [ $md5 != "ae63e55b951748cc486645f532ba230b" ] && echo "Wrong md5sum of vox1_dev_wav.zip" && exit 1 +fi + + +echo "Download success !!!" diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/download_whisper.sh b/examples/voxceleb/v1/Whisper-PMFA/local/download_whisper.sh new file mode 100755 index 00000000..d0bf7a6b --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/download_whisper.sh @@ -0,0 +1,13 @@ +download_dir=data/whisper_pretrained_model + +. tools/parse_options.sh || exit 1 + +[ ! -d ${download_dir} ] && mkdir -p ${download_dir} + +if [ ! -f ${download_dir}/large-v2.pt ]; then + echo "Downloading large-v2.pt ..." + wget --no-check-certificate https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt -P ${download_dir} + md5=$(md5sum ${download_dir}/large-v2.pt | awk '{print $1}') + [ $md5 != "668764447eeda98eeba5ef7bfcb4cc3d" ] && echo "Wrong md5sum of musan.tar.gz" && exit 1 +fi + diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/extract_vox.sh b/examples/voxceleb/v1/Whisper-PMFA/local/extract_vox.sh new file mode 100755 index 00000000..613012c1 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/extract_vox.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exp_dir='' +model_path='' +nj=4 +gpus="[0,1]" +data_type="shard" # shard/raw/feat +data=data + +. tools/parse_options.sh +set -e + +data_name_array=("vox1_dev" "vox1_test") +data_list_path_array=("${data}/vox1_dev/${data_type}.list" "${data}/vox1_test/${data_type}.list") +data_scp_path_array=("${data}/vox1_dev/wav.scp" "${data}/vox1_test/wav.scp") # to count the number of wavs +nj_array=($nj $nj) +batch_size_array=(16 1) # batch_size of test set must be 1 !!! +num_workers_array=(4 1) +count=${#data_name_array[@]} + +for i in $(seq 0 $(($count - 1))); do + wavs_num=$(wc -l ${data_scp_path_array[$i]} | awk '{print $1}') + bash tools/extract_embedding.sh --exp_dir ${exp_dir} \ + --model_path $model_path \ + --data_type ${data_type} \ + --data_list ${data_list_path_array[$i]} \ + --wavs_num ${wavs_num} \ + --store_dir ${data_name_array[$i]} \ + --batch_size ${batch_size_array[$i]} \ + --num_workers ${num_workers_array[$i]} \ + --nj ${nj_array[$i]} \ + --gpus $gpus & +done + +wait + +echo "Embedding dir is (${exp_dir}/embeddings)." diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/prepare_data.sh b/examples/voxceleb/v1/Whisper-PMFA/local/prepare_data.sh new file mode 100755 index 00000000..6b55499e --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/prepare_data.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +stage=-1 +stop_stage=-1 +data=data + +. tools/parse_options.sh || exit 1 + +data=`realpath ${data}` +download_dir=${data}/download_data +rawdata_dir=${data}/raw_data + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Download musan.tar.gz, rirs_noises.zip, vox1_test_wav.zip, and vox1_dev_wav.zip." + echo "This may take a long time. Thus we recommand you to download all archives above in your own way first." + + ./local/download_data.sh --download_dir ${download_dir} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Decompress all archives ..." + echo "This could take some time ..." + + for archive in musan.tar.gz rirs_noises.zip vox1_test_wav.zip vox1_dev_wav.zip; do + [ ! -f ${download_dir}/$archive ] && echo "Archive $archive not exists !!!" && exit 1 + done + [ ! -d ${rawdata_dir} ] && mkdir -p ${rawdata_dir} + + if [ ! -d ${rawdata_dir}/musan ]; then + tar -xzvf ${download_dir}/musan.tar.gz -C ${rawdata_dir} + fi + + if [ ! -d ${rawdata_dir}/RIRS_NOISES ]; then + unzip ${download_dir}/rirs_noises.zip -d ${rawdata_dir} + fi + + if [ ! -d ${rawdata_dir}/voxceleb1 ]; then + mkdir -p ${rawdata_dir}/voxceleb1/test ${rawdata_dir}/voxceleb1/dev + unzip ${download_dir}/vox1_test_wav.zip -d ${rawdata_dir}/voxceleb1/test + unzip ${download_dir}/vox1_dev_wav.zip -d ${rawdata_dir}/voxceleb1/dev + fi + + echo "Decompress success !!!" +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Prepare wav.scp for each dataset ..." + export LC_ALL=C # kaldi config + + mkdir -p ${data}/musan ${data}/rirs ${data}/vox1_dev ${data}/vox1_test + # musan + find ${rawdata_dir}/musan -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/musan/wav.scp + # rirs + find ${rawdata_dir}/RIRS_NOISES/simulated_rirs -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' >${data}/rirs/wav.scp + # vox1 dev + find ${rawdata_dir}/voxceleb1/dev -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' | sort >${data}/vox1_dev/wav.scp + awk '{print $1}' ${data}/vox1_dev/wav.scp | awk -F "/" '{print $0,$1}' >${data}/vox1_dev/utt2spk + ./tools/utt2spk_to_spk2utt.pl ${data}/vox1_dev/utt2spk >${data}/vox1_dev/spk2utt + # vox1 test + find ${rawdata_dir}/voxceleb1/test -name "*.wav" | awk -F"/" '{print $(NF-2)"/"$(NF-1)"/"$NF,$0}' | sort >${data}/vox1_test/wav.scp + awk '{print $1}' ${data}/vox1_test/wav.scp | awk -F "/" '{print $0,$1}' >${data}/vox1_test/utt2spk + ./tools/utt2spk_to_spk2utt.pl ${data}/vox1_test/utt2spk >${data}/vox1_test/spk2utt + + if [ ! -d ${data}/vox1_test/trials ]; then + echo "Download trials for vox1_test ..." + mkdir -p ${data}/vox1_test/trials + #wget --no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt -O ${data}/vox1_test/trials/vox1-O.txt + wget --no-check-certificate https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt -O ${data}/vox1_test/trials/vox1-O\(cleaned\).txt + # transform them into kaldi trial format + awk '{if($1==0)label="nontarget";else{label="target"}; print $2,$3,label}' ${data}/vox1_test/trials/vox1-O\(cleaned\).txt >${data}/vox1_test/trials/vox1_O_cleaned.kaldi + fi + + echo "Success !!!" +fi diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/score.sh b/examples/voxceleb/v1/Whisper-PMFA/local/score.sh new file mode 100755 index 00000000..5b81a883 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/score.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# Copyright (c) 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exp_dir= +trials="vox1_O_cleaned.kaldi vox1_E_cleaned.kaldi vox1_H_cleaned.kaldi" +data=data + +stage=-1 +stop_stage=-1 + +. tools/parse_options.sh +. path.sh + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "apply cosine scoring ..." + mkdir -p ${exp_dir}/scores + trials_dir=${data}/vox1_test/trials + for x in $trials; do + echo $x + python wespeaker/bin/score.py \ + --exp_dir ${exp_dir} \ + --eval_scp_path ${exp_dir}/embeddings/vox1_test/xvector.scp \ + --cal_mean True \ + --cal_mean_dir ${exp_dir}/embeddings/vox1_dev \ + ${trials_dir}/${x} + done +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "compute metrics (EER/minDCF) ..." + scores_dir=${exp_dir}/scores + for x in $trials; do + python wespeaker/bin/compute_metrics.py \ + --p_target 0.01 \ + --c_fa 1 \ + --c_miss 1 \ + ${scores_dir}/${x}.score \ + 2>&1 | tee -a ${scores_dir}/vox1_cos_result + + echo "compute DET curve ..." + python wespeaker/bin/compute_det.py \ + ${scores_dir}/${x}.score + done +fi diff --git a/examples/voxceleb/v1/Whisper-PMFA/local/score_norm.sh b/examples/voxceleb/v1/Whisper-PMFA/local/score_norm.sh new file mode 100755 index 00000000..73431093 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/local/score_norm.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Copyright (c) 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +score_norm_method="asnorm" # asnorm/snorm +cohort_set=vox2_dev +top_n=100 +exp_dir= +trials="vox1_O_cleaned.kaldi vox1_E_cleaned.kaldi vox1_H_cleaned.kaldi" +data=data + +stage=-1 +stop_stage=-1 + +. tools/parse_options.sh +. path.sh + + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "compute mean xvector" + python tools/vector_mean.py \ + --spk2utt ${data}/${cohort_set}/spk2utt \ + --xvector_scp $exp_dir/embeddings/${cohort_set}/xvector.scp \ + --spk_xvector_ark $exp_dir/embeddings/${cohort_set}/spk_xvector.ark +fi + +output_name=${cohort_set}_${score_norm_method} +[ "${score_norm_method}" == "asnorm" ] && output_name=${output_name}${top_n} +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "compute norm score" + for x in $trials; do + python wespeaker/bin/score_norm.py \ + --score_norm_method $score_norm_method \ + --top_n $top_n \ + --trial_score_file $exp_dir/scores/${x}.score \ + --score_norm_file $exp_dir/scores/${output_name}_${x}.score \ + --cohort_emb_scp ${exp_dir}/embeddings/${cohort_set}/spk_xvector.scp \ + --eval_emb_scp ${exp_dir}/embeddings/vox1_test/xvector.scp \ + --mean_vec_path ${exp_dir}/embeddings/vox1_dev/mean_vec.npy + done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "compute metrics" + for x in ${trials}; do + scores_dir=${exp_dir}/scores + python wespeaker/bin/compute_metrics.py \ + --p_target 0.01 \ + --c_fa 1 \ + --c_miss 1 \ + ${scores_dir}/${output_name}_${x}.score \ + 2>&1 | tee -a ${scores_dir}/vox1_${score_norm_method}${top_n}_result + + python wespeaker/bin/compute_det.py \ + ${scores_dir}/${output_name}_${x}.score + done +fi diff --git a/examples/voxceleb/v1/Whisper-PMFA/path.sh b/examples/voxceleb/v1/Whisper-PMFA/path.sh new file mode 100644 index 00000000..e7917ccb --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/path.sh @@ -0,0 +1,5 @@ +export PATH=$PWD:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PWD:$PYTHONPATH diff --git a/examples/voxceleb/v1/Whisper-PMFA/run.sh b/examples/voxceleb/v1/Whisper-PMFA/run.sh new file mode 100644 index 00000000..42b63023 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/run.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +# Copyright 2022 Hongji Wang (jijijiang77@gmail.com) +# 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) +# 2022 Zhengyang Chen (chenzhengyang117@gmail.com) + +. ./path.sh || exit 1 + +stage=3 +stop_stage=3 + +data=data +data_type="raw" # shard/raw +model=whisper_PMFA_large_v2 + +exp_dir=exp/Whisper_PMFA_large_v2_voxceleb1_mel_5s + +gpus="[0]" +num_avg=10 +checkpoint= + +trials="vox1_O_cleaned.kaldi" + +score_norm_method="asnorm" # asnorm/snorm +top_n=300 + +. tools/parse_options.sh || exit 1 + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Preparing datasets ..." + ./local/prepare_data.sh --stage 1 --stop_stage 3 --data ${data} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Covert train and test data to ${data_type}..." + for dset in vox1_dev vox1_test; do + if [ $data_type == "shard" ]; then + python tools/make_shard_list.py --num_utts_per_shard 1000 \ + --num_threads 16 \ + --prefix shards \ + --shuffle \ + ${data}/$dset/wav.scp ${data}/$dset/utt2spk \ + ${data}/$dset/shards ${data}/$dset/shard.list + else + python tools/make_raw_list.py ${data}/$dset/wav.scp \ + ${data}/$dset/utt2spk ${data}/$dset/raw.list + fi + done + # Convert all musan data to LMDB + python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb + # Convert all rirs data to LMDB + python tools/make_lmdb.py ${data}/rirs/wav.scp ${data}/rirs/lmdb +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Start training with frozen whisper parameter..." + config=conf/whisper_PMFA_stage0.yaml + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wespeaker/bin/train.py --config $config \ + --exp_dir ${exp_dir} \ + --gpus $gpus \ + --num_avg ${num_avg} \ + --data_type "${data_type}" \ + --train_data ${data}/vox1_dev/${data_type}.list \ + --train_label ${data}/vox1_dev/utt2spk \ + --reverb_data ${data}/rirs/lmdb \ + --noise_data ${data}/musan/lmdb \ + --model ${model} +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "Start training with all parameter..." + + if [ -f ${exp_dir}/"config.yaml" ]; then + mv ${exp_dir}/"config.yaml" ${exp_dir}/"config_stage0.yaml" + fi + if [ -f ${exp_dir}/models/"final_model.pt" ]; then + mv ${exp_dir}/models/"final_model.pt" ${exp_dir}/models/"final_model_stage0.pt" + fi + + config=conf/whisper_PMFA_stage1.yaml + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + checkpoint=${exp_dir}/models/model_4.pt + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wespeaker/bin/train.py --config $config \ + --exp_dir ${exp_dir} \ + --gpus $gpus \ + --num_avg ${num_avg} \ + --data_type "${data_type}" \ + --train_data ${data}/vox1_dev/${data_type}.list \ + --train_label ${data}/vox1_dev/utt2spk \ + --reverb_data ${data}/rirs/lmdb \ + --noise_data ${data}/musan/lmdb \ + --model ${model} \ + --checkpoint ${checkpoint} +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + model_path=$exp_dir/models/final_model.pt + echo "Extract embeddings ..." + local/extract_vox.sh \ + --exp_dir $exp_dir --model_path $model_path \ + --nj 2 --gpus $gpus --data_type raw --data ${data} +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "Score ..." + local/score.sh \ + --stage 1 --stop-stage 2 \ + --exp_dir $exp_dir \ + --data ${data} \ + --trials "$trials" +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + echo "Score norm ..." + local/score_norm.sh \ + --stage 1 --stop-stage 3 \ + --score_norm_method $score_norm_method \ + --cohort_set vox1_dev \ + --top_n $top_n \ + --exp_dir $exp_dir \ + --data ${data} \ + --trials "$trials" +fi diff --git a/examples/voxceleb/v1/Whisper-PMFA/tools b/examples/voxceleb/v1/Whisper-PMFA/tools new file mode 120000 index 00000000..dc53b6d1 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/tools @@ -0,0 +1 @@ +/nfs/data/raid02/user/zhaoyiyang/git/wespeaker/tools \ No newline at end of file diff --git a/examples/voxceleb/v1/Whisper-PMFA/wespeaker b/examples/voxceleb/v1/Whisper-PMFA/wespeaker new file mode 120000 index 00000000..2bd78bf0 --- /dev/null +++ b/examples/voxceleb/v1/Whisper-PMFA/wespeaker @@ -0,0 +1 @@ +/nfs/data/raid02/user/zhaoyiyang/git/wespeaker/wespeaker \ No newline at end of file diff --git a/wespeaker/bin/extract.py b/wespeaker/bin/extract.py index dfa5fdac..032d98c3 100644 --- a/wespeaker/bin/extract.py +++ b/wespeaker/bin/extract.py @@ -47,7 +47,7 @@ def extract(config='conf/config.yaml', **kwargs): # model: frontend (optional) => speaker model model = get_speaker_model(configs['model'])(**configs['model_args']) frontend_type = test_conf.get('frontend', 'fbank') - if frontend_type == 's3prl': + if frontend_type != 'fbank': frontend_args = frontend_type + "_args" print('Initializing frontend model (this could take some time) ...') frontend = frontend_class_dict[frontend_type]( diff --git a/wespeaker/bin/train.py b/wespeaker/bin/train.py index 3c534353..2918ea97 100644 --- a/wespeaker/bin/train.py +++ b/wespeaker/bin/train.py @@ -109,18 +109,17 @@ def train(config='conf/config.yaml', **kwargs): logger.info("<== Model ==>") # frontend: fbank or s3prl frontend_type = configs['dataset_args'].get('frontend', 'fbank') - if frontend_type == 's3prl': + if frontend_type != "fbank": frontend_args = frontend_type + "_args" frontend = frontend_class_dict[frontend_type]( - **configs['dataset_args'][frontend_args], - sample_rate=configs['dataset_args']['resample_rate']) - # speaker model + **configs['dataset_args'][frontend_args], + sample_rate=configs['dataset_args']['resample_rate']) configs['model_args']['feat_dim'] = frontend.output_size() model = get_speaker_model(configs['model'])(**configs['model_args']) model.add_module("frontend", frontend) - else: # == 'fbank' - # speaker model + else: model = get_speaker_model(configs['model'])(**configs['model_args']) + if rank == 0: num_params = sum(param.numel() for param in model.parameters()) logger.info('speaker_model size: {}'.format(num_params)) diff --git a/wespeaker/frontend/__init__.py b/wespeaker/frontend/__init__.py index 9b9fd27b..a210b9d3 100644 --- a/wespeaker/frontend/__init__.py +++ b/wespeaker/frontend/__init__.py @@ -14,5 +14,6 @@ from .s3prl import S3prlFrontend +from .whisper_encoder import whisper_encoder -frontend_class_dict = {'s3prl' : S3prlFrontend} +frontend_class_dict = {'fbank' : None, 's3prl' : S3prlFrontend, 'whisper_encoder' : whisper_encoder} diff --git a/wespeaker/frontend/whisper_encoder.py b/wespeaker/frontend/whisper_encoder.py new file mode 100644 index 00000000..4faeeeec --- /dev/null +++ b/wespeaker/frontend/whisper_encoder.py @@ -0,0 +1,273 @@ +import numpy as np +import torch +import torch.nn.functional as F +import torch.distributed as dist +from torch import Tensor +from torch import nn + +from typing import Iterable, Optional +import wespeaker.models.pooling_layers as pooling_layers + +import os +import hashlib +import whisper +import logging +import urllib.request + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, layer_st: int, layer_ed: int): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + # self.ln_post = LayerNorm(n_state) + # ------------------------ADD:add new layer norm----------------------------- + self.ln_post2 = LayerNorm(n_state * (layer_ed - layer_st + 1)) + + self.layer_st = layer_st + self.layer_ed = layer_ed + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + # ---------------------------ADD------------------------ + x = x.permute(0, 2, 1) + + x = x.squeeze(1) + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + # ----------------------Change:Tailor the positional_embedding------------------------- + assert x.shape[2:] == self.positional_embedding.shape[1:], "incorrect audio shape" + if self.positional_embedding.shape[0] > x.shape[1]: + temp_positional_embedding = self.positional_embedding[:x.shape[1], :] + elif self.positional_embedding.shape[0] < x.shape[1]: + x = x[:,:self.positional_embedding.shape[0],:] + temp_positional_embedding = self.positional_embedding + else: + temp_positional_embedding = self.positional_embedding + + x = (x + temp_positional_embedding).to(x.dtype) + + # ------------------------------------Change: Concat block outputs------------------------------------ + out = [] + for i, block in enumerate(self.blocks): + x = block(x) + if self.layer_st <= i <= self.layer_ed: + out.append(x) + + xs = torch.cat(out, dim=-1) + + xs = self.ln_post2(xs) + return xs + + +class whisper_encoder(torch.nn.Module): + def __init__(self, + frozen=False, + n_mels=80, + num_blocks=24, + output_size=1280, + n_head=20, + layer_st=16, + layer_ed=23, + model_path=None, + sample_rate=16000 + ): + super(whisper_encoder, self).__init__() + self.encoder = AudioEncoder(n_mels=n_mels, n_layer=num_blocks, n_state=output_size, n_ctx=1500, + n_head=n_head, layer_st=layer_st, layer_ed=layer_ed) + # 0 for freeze finetune, 1 for all parameters finetune + self.frozen = frozen + self.single_output_size = output_size + self.concat_layer = layer_ed - layer_st + 1 + self.n_mels = n_mels + + # load model + if model_path: + if dist.is_initialized(): + if dist.get_rank() == 0: + self._download_whisper_model(model_path) + dist.barrier() # Wait for rank 0 to finish downloading + self._load_pretrained_weights(model_path) + else: + self._download_whisper_model(model_path) + self._load_pretrained_weights(model_path) + + if self.frozen: + for param in self.encoder.parameters(): + param.requires_grad_(False) + + + def _download_whisper_model(self, model_path='whisper_hub/large-v2.pt'): + download_dir = os.path.dirname(model_path) + + if not os.path.exists(download_dir): + os.makedirs(download_dir) + + if not os.path.isfile(model_path): + print("Downloading large-v2.pt ...") + url = 'https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt' + + urllib.request.urlretrieve(url, model_path) + + md5 = hashlib.md5(open(model_path, 'rb').read()).hexdigest() + + if md5 != "668764447eeda98eeba5ef7bfcb4cc3d": + print("Wrong md5sum of large-v2.pt") + os.remove(model_path) + raise ValueError("MD5 checksum does not match!") + else: + print("Model already downloaded.") + + + def _load_pretrained_weights(self, model_path): + print(f"Loading pretrained weights from {model_path}...") + + state_dict = torch.load(model_path, map_location=torch.device('cpu')) + state_dict = state_dict['model_state_dict'] + + new_state_dict = {} + for k, v in state_dict.items(): + new_key = k.replace('encoder.', '', 1) + new_state_dict[new_key] = v + + missing_keys, unexpected_keys = self.encoder.load_state_dict(new_state_dict,strict=False) + print("Pretrained weights loaded successfully.") + for key in missing_keys: + logging.warning('missing tensor: {}'.format(key)) + for key in unexpected_keys: + logging.warning('unexpected tensor: {}'.format(key)) + + + def output_size(self): + return self.single_output_size*self.concat_layer + + def forward(self, wavs, wavs_len): + with torch.no_grad(): + processed_feats = [] + for i in range(wavs.size(0)): + tf_tensor = wavs[i].unsqueeze(0).to(wavs.device) + mat = whisper.log_mel_spectrogram(tf_tensor.squeeze(), n_mels=self.n_mels) + processed_feats.append(mat) + + feat = torch.stack(processed_feats, dim=0).to(wavs.device) + + feat = feat.transpose(1, 2) + # (B,T,F) + x = self.encoder(feat) + return x, None + + diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py index 8475f1ae..b08fbda8 100644 --- a/wespeaker/models/speaker_model.py +++ b/wespeaker/models/speaker_model.py @@ -20,6 +20,7 @@ import wespeaker.models.eres2net as eres2net import wespeaker.models.gemini_dfresnet as gemini import wespeaker.models.res2net as res2net +import wespeaker.models.whisper_PMFA as whisper_PMFA def get_speaker_model(model_name: str): @@ -39,6 +40,8 @@ def get_speaker_model(model_name: str): return getattr(res2net, model_name) elif model_name.startswith("Gemini"): return getattr(gemini, model_name) + elif model_name.startswith("whisper_PMFA"): + return getattr(whisper_PMFA, model_name) else: # model_name error !!! print(model_name + " not found !!!") exit(1) diff --git a/wespeaker/models/whisper_PMFA.py b/wespeaker/models/whisper_PMFA.py new file mode 100644 index 00000000..09f878f9 --- /dev/null +++ b/wespeaker/models/whisper_PMFA.py @@ -0,0 +1,125 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn + +import wespeaker.models.pooling_layers as pooling_layers + +class BatchNorm1d(nn.Module): + """Applies 1d batch normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + combine_batch_time : bool + When true, it combines batch an time axis. + + + Example + ------- + >>> input = torch.randn(100, 10) + >>> norm = BatchNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + combine_batch_time=False, + skip_transpose=True, + ): + super().__init__() + self.combine_batch_time = combine_batch_time + self.skip_transpose = skip_transpose + + if input_size is None and skip_transpose: + input_size = input_shape[1] + elif input_size is None: + input_size = input_shape[-1] + + self.norm = nn.BatchNorm1d( + input_size, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, [channels]) + input to normalize. 2d or 3d tensors are expected in input + 4d tensors can be used when combine_dims=True. + """ + shape_or = x.shape + if self.combine_batch_time: + if x.ndim == 3: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) + else: + x = x.reshape( + shape_or[0] * shape_or[1], shape_or[3], shape_or[2] + ) + + elif not self.skip_transpose: + x = x.transpose(-1, 1) + + x_n = self.norm(x) + + if self.combine_batch_time: + x_n = x_n.reshape(shape_or) + elif not self.skip_transpose: + x_n = x_n.transpose(1, -1) + + return x_n + + +class whisper_PMFA(torch.nn.Module): + def __init__(self, output_size=1280, embedding_dim=192, pooling_func='ASTP',global_context_att=True): + super(whisper_PMFA, self).__init__() + self.pooling = getattr(pooling_layers, pooling_func)( + in_dim=output_size, global_context_att=global_context_att) + self.bn = BatchNorm1d(input_size=output_size*2) + self.fc = torch.nn.Linear(output_size*2, embedding_dim) + + + def forward(self, x): + x = x.permute(0, 2, 1) + x = self.pooling(x) + x = x.unsqueeze(-1) + x = self.bn(x) + x = x.permute(0, 2, 1) + x = self.fc(x) + x = x.squeeze(1) + return x + + + + +def whisper_PMFA_large_v2(feat_dim, embed_dim): + return whisper_PMFA(output_size=feat_dim, + embedding_dim=embed_dim + ) \ No newline at end of file