From 13a6f0a63f843ed9ef7658c816a97d21c8a939fc Mon Sep 17 00:00:00 2001 From: genisplaja Date: Wed, 25 Oct 2023 12:28:17 +0200 Subject: [PATCH] v0.1 --- .gitignore | 188 +++----------------- LICENSE | 222 +++--------------------- README.md | 49 +++++- ckpt/README.md | 4 + config.py | 86 +++++++++ dataset/__init__.py | 1 + dataset/config.py | 44 +++++ dataset/prepare_saraga.py | 59 +++++++ dataset/saraga.py | 106 +++++++++++ model/__init__.py | 145 ++++++++++++++++ model/clustering.py | 20 +++ model/config.py | 86 +++++++++ model/estnoise_ms.py | 347 +++++++++++++++++++++++++++++++++++++ model/unet.py | 326 ++++++++++++++++++++++++++++++++++ model/unet_utils.py | 40 +++++ model/vad.py | 86 +++++++++ requirements.txt | 7 + separate.py | 115 ++++++++++++ train.py | 276 +++++++++++++++++++++++++++++ utils/noam_schedule.py | 32 ++++ utils/phase_vocoder.py | 59 +++++++ utils/separation_eval.py | 10 ++ utils/signal_processing.py | 74 ++++++++ 23 files changed, 2020 insertions(+), 362 deletions(-) create mode 100644 ckpt/README.md create mode 100644 config.py create mode 100644 dataset/__init__.py create mode 100644 dataset/config.py create mode 100644 dataset/prepare_saraga.py create mode 100644 dataset/saraga.py create mode 100644 model/__init__.py create mode 100644 model/clustering.py create mode 100644 model/config.py create mode 100644 model/estnoise_ms.py create mode 100644 model/unet.py create mode 100644 model/unet_utils.py create mode 100644 model/vad.py create mode 100644 requirements.txt create mode 100644 separate.py create mode 100644 train.py create mode 100644 utils/noam_schedule.py create mode 100644 utils/phase_vocoder.py create mode 100644 utils/separation_eval.py create mode 100644 utils/signal_processing.py diff --git a/.gitignore b/.gitignore index 68bc17f..fefb868 100644 --- a/.gitignore +++ b/.gitignore @@ -1,160 +1,28 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. 
-#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# python +__pycache__ +ckpt/__pycache__ +dataset/__pycache__ +model/__pycache__ +utils/__pycache__ + +# mac osX +.DS_Store +ckpt/.DS_Store +dataset/.DS_Store +model/.DS_Store +utils/.DS_Store + +# train +ckpt/saraga-8/ +ckpt/saraga-8.json +log + +# test sample +sample +output + +# testing stuff for paper +testing_files/ + +# wip for faster, cleaner, and better execution +evaluate.py \ No newline at end of file diff --git a/LICENSE b/LICENSE index 261eeb9..d5af8fd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2020 YoungJoong Kim + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 0eb0e73..fc8e2ee 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,48 @@
-# carnatic-separation-ismir23
\ No newline at end of file
+# Carnatic singing voice separation trained with in-domain data with leakage
+This is the official repository for:
+
+- Carnatic Singing Voice Separation Using Cold Diffusion on Training Data with Bleeding, G. Plaja-Roglans, M. Miron, A. Shankar and X. Serra, 2023 (accepted for presentation at ISMIR 2023, Milan, Italy).
+
+**Important note:** The code structure and an important part of the data loader and training code are an adaptation of the unofficial TensorFlow implementation of DiffWave (Zhifeng Kong et al., 2020). [Link to original repo](https://github.com/revsic/tf-diffwave).
+
+**Another important note:** The model in this repo can also be used for easy inference through the Python library [compIAM](https://github.com/MTG/compIAM), a centralized repository of tools, models, and datasets for the computational analysis of Carnatic and Hindustani Music. With a few commands, you can easily download and run the separation model. Refer to `compIAM` to use this model (and many others!) out-of-the-box.
+
+## Requirements
+
+The repository is based on TensorFlow 2. [See a complete list of requirements here](./requirements.txt).
+
+## Run separation inference
+
+To run separation inference, use the `separate.py` script.
+
+```bash
+python3 separate.py --input-signal /path/to/file.wav --clusters 5 --scheduler 4
+```
+
+Additional arguments can be passed to use a different model (`--model-name`), modify the batch size (i.e., the chunk size processed by the model for optimized inference, `--batch-size`), and specify to which GPU the process should be routed (`--gpu`).
+
+## Train the model
+
+To train your own model, you should first prepare the data. See [how we process Saraga](./dataset/prepare_saraga.py) before running the training process detailed in the paper. The key idea is to have the chunked and aligned audio samples of the dataset with a naming like `<track_name>_<source>.wav`, where `<source>` corresponds to `mixture` and `vocals`.
+
+Then, run model training with [train.py](./train.py). Checkpoints will be stored every X training steps, where X is defined by the user in the [config.py](./config.py) file.
+
+To resume training from a previous checkpoint, use `--load-step`.
+
+```bash
+python3 train.py --load-step 416 --config ./ckpt/<model-name>.json
+```
+
+Download the pre-trained weights for the feature extraction U-Net [here](https://drive.google.com/uc?export=download&id=1yj9iHTY7nCh2qrIM2RIUOXhLXt1K8WcE).
+
+Unzip and store the weights in the [ckpt folder](./ckpt/). There should be a `.json` file with the configuration and a folder with the TF weights inside. Here's an example:
+
+```py
+with open('./ckpt/saraga-8.json') as f:
+    config = Config.load(json.load(f))
+
+diffwave = DiffWave(config.model)
+diffwave.restore('./ckpt/saraga-8/saraga-8.ckpt-1').expect_partial()
+```
+
+[Write us](mailto:genis.plaja@upf.edu) or open an issue if you have any questions!
\ No newline at end of file
diff --git a/ckpt/README.md b/ckpt/README.md
new file mode 100644
index 0000000..0b69373
--- /dev/null
+++ b/ckpt/README.md
@@ -0,0 +1,4 @@
+## Model weights
+Download and store the model weights here.
The folder structure, for a model named `saraga-8`, should look like that: +* `.ckpt/saraga-8/saraga-8.json`: config file for `saraga-8` model. +* `.ckpt/saraga-8/saraga-8/`: folder containing the checkpoint for the `saraga-8` model. \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..6245443 --- /dev/null +++ b/config.py @@ -0,0 +1,86 @@ +from dataset.config import Config as UnetDataConfig +from model.config import Config as UnetModelConfig +from utils.noam_schedule import NoamScheduler + + +class UnetTrainConfig: + """Configuration for training loop. + """ + def __init__(self): + # optimizer + self.lr_policy = 'fixed' + self.learning_rate = 0.000002 + # self.lr_policy = 'noam' + # self.learning_rate = 1 + # self.lr_params = { + # 'warmup_steps': 4000, + # 'channels': 64 + # } + + self.beta1 = 0.9 + self.beta2 = 0.98 + self.eps = 1e-9 + + # 13000:100 + self.split = 3001*8 +# self.split = 50 + self.bufsiz = 50 + + self.epoch = 10000 + + # path config + self.log = './log' + self.ckpt = './ckpt' + self.sounds = './sounds' + + # model name + self.model_type = None + self.name = 'saraga-8' + + # interval configuration + self.eval_intval = 5000 + self.ckpt_intval = 10000 + def lr(self): + """Generate proper learning rate scheduler. + """ + mapper = { + 'noam': NoamScheduler + } + if self.lr_policy == 'fixed': + return self.learning_rate + if self.lr_policy in mapper: + return mapper[self.lr_policy](self.learning_rate, **self.lr_params) + raise ValueError('invalid lr_policy') + +class Config(): + """Integrated configuration. + """ + def __init__(self): + self.data = UnetDataConfig() + self.model = UnetModelConfig() + self.train = UnetTrainConfig() + + def dump(self): + """Dump configurations into serializable dictionary. + """ + return {k: vars(v) for k, v in vars(self).items()} + + @staticmethod + def load(dump_): + """Load dumped configurations into new configuration. + """ + conf = Config() + for k, v in dump_.items(): + if hasattr(conf, k): + obj = getattr(conf, k) + load_state(obj, v) + return conf + + +def load_state(obj, dump_): + """Load dictionary items to attributes. + """ + for k, v in dump_.items(): + if hasattr(obj, k): + setattr(obj, k, v) + return obj diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..0cca756 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1 @@ +from .saraga import SARAGA \ No newline at end of file diff --git a/dataset/config.py b/dataset/config.py new file mode 100644 index 0000000..c4f4131 --- /dev/null +++ b/dataset/config.py @@ -0,0 +1,44 @@ +import tensorflow as tf + +class Config: + """Configuration for dataset construction. + """ + def __init__(self): + # audio config + self.sr = 22050 + + # stft + self.hop = 256 + self.win = 1024 + self.fft = self.win + self.win_fn = 'hann' + + # mel-scale filter bank + self.mel = 80 + self.fmin = 0 + self.fmax = 8000 + + self.eps = 1e-5 + + # sample size + self.frames = (self.hop + 6) * 128 # 16384 + self.batch = 8 + + self.eval_tracks = [ + 'kailasapathe', 'ragam_tanam_pallavi', 'gopi_gopala_bala', 'ananda_sagara' + ] # four manually selected tracks for validation + + def window_fn(self): + """Return window generator. + Returns: + Callable, window function of tf.signal + , which corresponds to self.win_fn. 
+ """ + mapper = { + 'hann': tf.signal.hann_window, + 'hamming': tf.signal.hamming_window + } + if self.win_fn in mapper: + return mapper[self.win_fn] + + raise ValueError('invalid window function: ' + self.win_fn) diff --git a/dataset/prepare_saraga.py b/dataset/prepare_saraga.py new file mode 100644 index 0000000..af3ab85 --- /dev/null +++ b/dataset/prepare_saraga.py @@ -0,0 +1,59 @@ +import os +import argparse +import torch +import glob +import tqdm + +import numpy as np +import torch.nn.functional as F +import torchaudio as T + +SR = 22050 + + +def main(args): + concert = glob.glob(os.path.join(args.saraga_dir, '*/')) + + for i in tqdm(concert): + songs = glob.glob(os.path.join(args.saraga_dir, i, '*/')) + for j in tqdm.tqdm(songs): + song_name = j.split("/")[-2] + mixture = os.path.join(j, song_name + ".mp3.mp3") + vocals = os.path.join(j, song_name + ".multitrack-vocal.mp3") + + if os.path.exists(mixture): + audio_mix, sr = T.load(mixture) + audio_voc, _ = T.load(vocals) + resampling = T.transforms.Resample(sr, SR) + audio_mix = resampling(audio_mix) + audio_voc = resampling(audio_voc) + audio_mix = torch.mean(audio_mix, dim=0).unsqueeze(0) + audio_mix = torch.clamp(audio_mix, -1.0, 1.0) + audio_voc = torch.mean(audio_voc, dim=0).unsqueeze(0) + audio_voc = torch.clamp(audio_voc, -1.0, 1.0) + + actual_len = audio_voc.shape + for trim in np.arange(actual_len[1] // (args.sample_len*SR)): + T.save( + os.path.join( + args.output_dir, song_name.lower().replace(" ", "_") + '_' + str(trim) + '_mixture.wav'), + audio_mix[:, trim*args.sample_len*SR:(trim+1)*args.sample_len*SR].cpu(), + sample_rate=sr, + bits_per_sample=16) + T.save( + os.path.join( + args.output_dir, song_name.lower().replace(" ", "_") + '_' + str(trim) + '_vocals.wav'), + audio_voc[:, trim*args.sample_len*SR:(trim+1)*args.sample_len*SR].cpu(), + sample_rate=sr, + bits_per_sample=16) + else: + print("no file...") + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--saraga-dir', default=None, type=str) + parser.add_argument('--output-dir', default=None, type=str) + parser.add_argument('--sample-len', default=6) + parser.add_argument('--gpu', default=None) + args = parser.parse_args() + main(args) diff --git a/dataset/saraga.py b/dataset/saraga.py new file mode 100644 index 0000000..d2230c2 --- /dev/null +++ b/dataset/saraga.py @@ -0,0 +1,106 @@ +import os +import glob +import tensorflow as tf + +class SARAGA: + """Saraga dataset loader. + Use other opensource vocoder settings, 16bit, sr: 22050. + """ + SR = 22050 + + def __init__(self, config, data_dir=None): + """Initializer. + Args: + config: Config, dataset configuration. + data_dir: str, dataset directory + , defaults to '~/tensorflow_datasets'. + download: bool, download dataset or not. + from_tfds: bool, load from tfrecord generated by tfds or read raw audio. + """ + self.config = config + self.rawset = self.load_data('train', data_dir) + self.valset = self.load_data('val', data_dir) + + self.normalized = None + + def load_data(self, subset='train', data_dir=None): + """Load dataset from tfrecord or raw audio files. + Args: + data_dir: str, dataset directory. + Returns: + tf.data.Dataset, data loader. 
+ """ + if subset == 'train': + mixture_files = glob.glob(os.path.join(data_dir, '*mixture.wav')) + for track in self.config.eval_tracks: + mixture_files = [x for x in mixture_files if track + '_' not in x] + else: + mixture_files = [] + for track in self.config.eval_tracks: + mixture_files += [x for x in glob.glob(os.path.join(data_dir, '*mixture.wav')) if (track + '_' in x) and \ + ('_' + track + '_' not in track)] + mixture_files = [x for x in mixture_files if 'silence' not in x] + files = tf.data.Dataset.from_tensor_slices( + [(mix, mix.replace('_mixture.', '_vocals.')) for mix in mixture_files]) + + return files.map(SARAGA._load_audio) + + @staticmethod + def _load_audio(paths): + """Load audio with tf apis. + Args: + path: str, wavfile path to read. + Returns: + tf.Tensor, [T], mono audio in range (-1, 1). + """ + mixture_audio, _ = tf.audio.decode_wav(tf.io.read_file(paths[0]), desired_channels=1) + vocal_audio, _ = tf.audio.decode_wav(tf.io.read_file(paths[1]), desired_channels=1) + return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1) + + def normalizer(self, frames=16000): + """Create dataset normalizer, make fixed size segment in range(-1, 1). + Args: + frames: int, segment size, frame unit. + from_tfds: bool, whether use tfds tfrecord or raw audio. + Returns: + Callable, normalizer. + """ + def normalize(mixture_signal, vocal_signal): #, accomp_signal): + """Normalize datum. + Args: + mixture_signal: tf.Tensor, [T], mono audio in range (-1, 1). + vocal_signal: tf.Tensor, [T], mono audio in range (-1, 1). + accomp_signal: tf.Tensor, [T], mono audio in range (-1, 1). + Returns: + tf.Tensor, [frames], fixed size mixture signal in range (-1, 1). + tf.Tensor, [frames], fixed size vocal signal in range (-1, 1). + tf.Tensor, [frames], fixed size accomp signal in range (-1, 1). + """ + nonlocal frames + frames = frames // self.config.hop * self.config.hop + start = tf.random.uniform( + (), 0, tf.shape(vocal_signal)[0] - frames, dtype=tf.int32) + return mixture_signal[start:start + frames], vocal_signal[start:start + frames] + + def dataset(self): + """Generate dataset. + """ + if self.normalized is None: + self.normalized = self.rawset \ + .map(self.normalizer(self.config.frames)) \ + .batch(self.config.batch) + return self.normalized + + def test_dataset(self): + """Generate dataset. + """ + return self.valset \ + .map(self.normalizer(self.config.frames)) \ + .batch(self.config.batch) + + def validation(self): + """Generate dataset. + """ + return self.valset \ + .map(self.normalizer(self.config.frames)) \ + .batch(1) diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..7677ac6 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,145 @@ +import numpy as np +import tensorflow as tf + +from .unet import UNet + + +class DiffWave(tf.keras.Model): + """Code copied and modified from DiffWave: A Versatile Diffusion Model for Audio Synthesis. + Zhifeng Kong et al., 2020. + """ + def __init__(self, config): + """Initializer. + Args: + config: Config, model configuration. + """ + super(DiffWave, self).__init__() + self.config = config + self.net = UNet(config) + + def call(self, signal, mode="predict", step_stop=0): + """Generate denoised audio. + Args: + signal: tf.Tensor, [B, T], starting signal for transformation. + Returns: + signal: tf.Tensor, [B, T], predicted output. 
+ """ + base = tf.ones([tf.shape(signal)[0]], dtype=tf.int32) + if mode=="train": + features = [] + for t in range(self.config.iter, step_stop, -1): + signal = self.pred_noise(signal, base * t) + if mode == "train": + features.append(signal) + if mode == "train": + return tf.convert_to_tensor(features) + else: + return signal + + def diffusion(self, mixture, vocal, alpha_bar): + """Compute conditions + """ + diffusion_step = lambda x : self._diffusion(x[0], x[1], x[2]) + return tf.map_fn( + fn=diffusion_step, elems=[mixture, vocal, alpha_bar], + fn_output_signature=(tf.float32)) + + def _diffusion(self, mixture, vocals, alpha_bar): + """Trans to next state with diffusion process. + Args: + signal: tf.Tensor, [B, T], signal. + alpha_bar: Union[float, tf.Tensor: [B]], cumprod(1 -beta). + eps: Optional[tf.Tensor: [B, T]], noise. + Return: + tuple, + noised: tf.Tensor, [B, T], noised signal. + eps: tf.Tensor, [B, T], noise. + """ + mix_mag = self.check_shape( + self.check_shape( + tf.abs(tf.signal.stft( + mixture, + frame_length=self.config.win, + frame_step=self.config.hop, + fft_length=self.config.win, + window_fn=tf.signal.hann_window)), 0), 1) + #print(mix_mag.shape) + vocal_mag = self.check_shape( + self.check_shape( + tf.abs(tf.signal.stft( + vocals, + frame_length=self.config.win, + frame_step=self.config.hop, + fft_length=self.config.win, + window_fn=tf.signal.hann_window)), 0), 1) + return tf.dtypes.cast(alpha_bar, tf.float32) * vocal_mag + \ + tf.dtypes.cast(1 - tf.sqrt(alpha_bar), tf.float32) * mix_mag + + @staticmethod + def check_shape(data, dim): + n = data.shape[dim] + if n % 2 != 0: + n = data.shape[dim] - 1 + if dim==0: + return data[:n, :] + else: + return data[:, :n] + + def pred_noise(self, signal, timestep): + """Predict noise from signal. + Args: + signal: tf.Tensor, [B, T], noised signal. + timestep: tf.Tensor, [B], timesteps of current markov chain. + mel: tf.Tensor, [B, T // hop, M], conditional mel-spectrogram. + Returns: + tf.Tensor, [B, T], predicted noise. + """ + return self.net(signal, timestep) + + def pred_signal(self, signal, eps, alpha, alpha_bar): + """Compute mean and stddev of denoised signal. + Args: + signal: tf.Tensor, [B, T], noised signal. + eps: tf.Tensor, [B, T], estimated noise. + alpha: float, 1 - beta. + alpha_bar: float, cumprod(1 - beta). + Returns: + tuple, + mean: tf.Tensor, [B, T], estimated mean of denoised signal. + """ + signal = tf.dtypes.cast(signal, tf.float64) + eps = tf.dtypes.cast(eps, tf.float64) + + # Compute mean (our estimation) using diffusion formulation + mean = (signal - (1 - alpha) / tf.dtypes.cast(tf.sqrt(1 - alpha_bar), tf.float64) * eps) / tf.dtypes.cast(tf.sqrt(alpha), tf.float64) + stddev = np.sqrt((1 - alpha_bar / alpha) / (1 - alpha_bar) * (1 - alpha)) + return mean, stddev + + def write(self, path, optim=None): + """Write checkpoint with `tf.train.Checkpoint`. + Args: + path: str, path to write. + optim: Optional[tf.keras.optimizers.Optimizer] + , optional optimizer. + """ + kwargs = {'model': self} + if optim is not None: + kwargs['optim'] = optim + ckpt = tf.train.Checkpoint(**kwargs) + ckpt.save(path) + + def restore(self, path, optim=None): + """Restore checkpoint with `tf.train.Checkpoint`. + Args: + path: str, path to restore. + optim: Optional[tf.keras.optimizers.Optimizer] + , optional optimizer. 
+ """ + kwargs = {'model': self} + if optim is not None: + kwargs['optim'] = optim + ckpt = tf.train.Checkpoint(**kwargs) + return ckpt.restore(path) + + + diff --git a/model/clustering.py b/model/clustering.py new file mode 100644 index 0000000..20bf6ce --- /dev/null +++ b/model/clustering.py @@ -0,0 +1,20 @@ +import numpy as np +import tensorflow as tf +from sklearn.cluster import KMeans + +def get_mask(normalized_feat, clusters, scheduler): + kmeans = KMeans(n_clusters=clusters, random_state=0).fit(normalized_feat) + centers = kmeans.cluster_centers_ + original_means = np.mean(centers, axis=1) + ordered_means = np.sort(np.mean(centers, axis=1)) + means_and_pos = {} + manual_weights = np.linspace(0, 1, clusters)**scheduler + for idx, j in zip(manual_weights, ordered_means): + means_and_pos[j] = idx + label_and_dist = [] + for j in original_means: + label_and_dist.append(means_and_pos[j]) + weights = [] + for j in kmeans.labels_: + weights.append(label_and_dist[j]) + return tf.math.divide(weights, float(clusters-1)) \ No newline at end of file diff --git a/model/config.py b/model/config.py new file mode 100644 index 0000000..cfb2936 --- /dev/null +++ b/model/config.py @@ -0,0 +1,86 @@ +import numpy as np +import tensorflow as tf + +class Config: + """Configuration for DiffWave implementation. + """ + def __init__(self): + self.model_type = None + + self.sr = 22050 + + self.hop = 256 + self.win = 1024 + + # mel-scale filter bank + self.mel = 80 + self.fmin = 0 + self.fmax = 8000 + + self.eps = 1e-5 + + # sample size + self.frames = (self.hop + 6) * 128 # 16384 + self.batch = 8 + + # leaky relu coefficient + self.leak = 0.4 + + # embdding config + self.embedding_size = 128 + self.embedding_proj = 512 + self.embedding_layers = 2 + self.embedding_factor = 4 + + # upsampler config + self.upsample_stride = [4, 1] + self.upsample_kernel = [32, 3] + self.upsample_layers = 4 + # computed hop size + # block config + self.channels = 64 + self.kernel_size = 3 + self.dilation_rate = 2 + self.num_layers = 30 + self.num_cycles = 3 + + # noise schedule + self.iter = 8 # 20, 40, 50 + self.noise_policy = 'linear' + self.noise_start = 1e-4 + self.noise_end = 0.5 # 0.02 for 200 + + def beta(self): + """Generate beta-sequence. + Returns: + List[float], [iter], beta values. + """ + mapper = { + 'linear': self._linear_sched, + } + if self.noise_policy not in mapper: + raise ValueError('invalid beta policy') + return mapper[self.noise_policy]() + + def _linear_sched(self): + """Linearly generated noise. + Returns: + List[float], [iter], beta values. + """ + return np.linspace( + self.noise_start, self.noise_end, self.iter, dtype=np.float32) + + def window_fn(self): + """Return window generator. + Returns: + Callable, window function of tf.signal + , which corresponds to self.win_fn. 
+ """ + mapper = { + 'hann': tf.signal.hann_window, + 'hamming': tf.signal.hamming_window + } + if self.win_fn in mapper: + return mapper[self.win_fn] + + raise ValueError('invalid window function: ' + self.win_fn) diff --git a/model/estnoise_ms.py b/model/estnoise_ms.py new file mode 100644 index 0000000..8f6fc11 --- /dev/null +++ b/model/estnoise_ms.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue May 1 20:43:28 2018 +@author: eesungkim +""" + +import numpy as np +from scipy.special import jv + +def bessel(v, X): + return ((1j**(-v))*jv(v,1j*X)).real + +def stft(x, n_fft=512, win_length=400, hop_length=160, window='hamming'): + if window == 'hanning': + window = np.hanning(win_length) + elif window == 'hamming': + window = np.hamming(win_length) + elif window == 'rectangle': + window = np.ones(win_length) + return np.array([np.fft.rfft(window*x[i:i+win_length],n_fft,axis=0) for i in range(0, len(x)-win_length, hop_length)]) + +def estnoisem(pSpectrum,hop_length): + """ + This is python implementation of [1],[2], and [3]. + + Refs: + [1] Rainer Martin. + Noise power spectral density estimation based on optimal smoothing and minimum statistics. + IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. + [2] Rainer Martin. + Bias compensation methods for minimum statistics noise power spectral density estimation + Signal Processing, 2006, 86, 1215-1229 + [3] Dirk Mauler and Rainer Martin + Noise power spectral density estimation on highly correlated data + Proc IWAENC, 2006 + + Copyright (C) Mike Brookes 2008 + Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ + + VOICEBOX is a MATLAB toolbox for speech processing. + Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html + """ + + (nFrames,nFFT2)=np.shape(pSpectrum) # number of frames and freq bins + x=np.array(np.zeros((nFrames,nFFT2)) ) # initialize output arrays + xs=np.array(np.zeros((nFrames,nFFT2)) ) # will hold std error in the future + + # default algorithm constants + taca= 0.0449 # smoothing time constant for alpha_c = -hop_length/log(0.7) in equ (11) + tamax= 0.392 # max smoothing time constant in (3) = -hop_length/log(0.96) + taminh= 0.0133 # min smoothing time constant (upper limit) in (3) = -hop_length/log(0.3) + tpfall= 0.064 # time constant for P to fall (12) + tbmax= 0.0717 # max smoothing time constant in (20) = -hop_length/log(0.8) + qeqmin= 2.0 # minimum value of Qeq (23) + qeqmax= 14.0 # max value of Qeq per frame + av= 2.12 # fudge factor for bc calculation (23 + 13 lines) + td= 1.536 # time to take minimum over + nu= 8 # number of subwindows + qith= np.array([0.03, 0.05, 0.06, np.Inf],dtype=float) # noise slope thresholds in dB/s + nsmdb= np.array([47, 31.4, 15.7, 4.1],dtype=float) # maximum permitted +ve noise slope in dB/s + + + # derived algorithm constants + aca=np.exp(-hop_length/taca) # smoothing constant for alpha_c in equ (11) = 0.7 + acmax=aca # min value of alpha_c = 0.7 in equ (11) also = 0.7 + amax=np.exp(-hop_length/tamax) # max smoothing constant in (3) = 0.96 + aminh=np.exp(-hop_length/taminh) # min smoothing constant (upper limit) in (3) = 0.3 + bmax=np.exp(-hop_length/tbmax) # max smoothing constant in (20) = 0.8 + SNRexp = -hop_length/tpfall + nv=round(td/(hop_length*nu)) # length of each subwindow in frames + + + if nv<4: # algorithm doesn't work for miniscule frames + nv=4 + nu=round(td/(hop_length*nv)) + nd=nu*nv # length of total window in frames + (md,hd,dd) = mhvals(nd) # calculate the constants M(D) and H(D) from Table III + (mv,hv,dv) = 
mhvals(nv) # calculate the constants M(D) and H(D) from Table III + nsms=np.array([10])**(nsmdb*nv*hop_length/10) # [8 4 2 1.2] in paper + qeqimax=1/qeqmin # maximum value of Qeq inverse (23) + qeqimin=1/qeqmax # minumum value of Qeq per frame inverse + + + p=pSpectrum[0,:] # smoothed power spectrum + ac=1 # correction factor (9) + sn2=p # estimated noise power + pb=p # smoothed noisy speech power (20) + pb2=pb**2 + pminu=p + actmin=np.array(np.ones(nFFT2) * np.Inf) # Running minimum estimate + actminsub=np.array(np.ones(nFFT2) * np.Inf) # sub-window minimum estimate + subwc=nv # force a buffer switch on first loop + actbuf=np.array(np.ones((nu,nFFT2)) * np.Inf) # buffer to store subwindow minima + ibuf=0 + lminflag=np.zeros(nFFT2) # flag to remember local minimum + + # loop for each frame + for t in range(0,nFrames): # we use t instead of lambda in the paper + pSpectrum_t=pSpectrum[t,:] # noise speech power spectrum + acb=(1+(sum(p) / sum(pSpectrum_t)-1)**2)**(-1) # alpha_c-bar(t) (9) + + tmp=np.array([acb] ) + tmp[tmp < acmax] = acmax + #max_complex(np.array([acb] ),np.array([acmax] )) + + ac=aca*ac+(1-aca)*tmp # alpha_c(t) (10) + + ah=amax*ac*(1+(p/sn2-1)**2)**(-1) # alpha_hat: smoothing factor per frequency (11) + SNR=sum(p)/sum(sn2) + + + ah=max_complex(ah,min_complex(np.array([aminh] ),np.array([SNR**SNRexp] ))) # lower limit for alpha_hat (12) + + p=ah*p+(1-ah)*pSpectrum_t # smoothed noisy speech power (3) + + b=min_complex(ah**2,np.array([bmax] )) # smoothing constant for estimating periodogram variance (22 + 2 lines) + pb=b*pb + (1-b)*p # smoothed periodogram (20) + pb2=b*pb2 + (1-b)*p**2 # smoothed periodogram squared (21) + + qeqi=max_complex(min_complex((pb2-pb**2)/(2*sn2**2),np.array([qeqimax] )),np.array([qeqimin/(t+1)] )) # Qeq inverse (23) + qiav=sum(qeqi)/nFFT2 # Average over all frequencies (23+12 lines) (ignore non-duplication of DC and nyquist terms) + bc=1+av*np.sqrt(qiav) # bias correction factor (23+11 lines) + bmind=1+2*(nd-1)*(1-md)/(qeqi**(-1)-2*md) # we use the signalmplified form (17) instead of (15) + bminv=1+2*(nv-1)*(1-mv)/(qeqi**(-1)-2*mv) # same expressignalon but for sub windows + kmod=(bc*p*bmind) < actmin # Frequency mask for new minimum + + if any(kmod): + actmin[kmod]=bc*p[kmod]*bmind[kmod] + actminsub[kmod]=bc*p[kmod]*bminv[kmod] + + if subwc>1 and subwc=nv: # end of buffer - do a buffer switch + ibuf=1+(ibuf%nu) # increment actbuf storage pointer + actbuf[ibuf-1,:]=actmin.copy() # save sub-window minimum + pminu=min_complex_mat(actbuf) + i=np.nonzero(np.array(qiav )pminu) + if any(lmin): + pminu[lmin]=actminsub[lmin] + actbuf[:,lmin]= np.ones((nu,1)) * pminu[lmin] + lminflag[:]=0 + actmin[:]=np.Inf + subwc=0 + + subwc=subwc+1 + x[t,:]=sn2.copy() + qisq=np.sqrt(qeqi) + # empirical formula for standard error based on Fig 15 of [2] + xs[t,:]=sn2*np.sqrt(0.266*(nd+100*qisq)*qisq/(1+0.005*nd+6/nd)/(0.5*qeqi**(-1)+nd-1)) + + + return x + +def mhvals(*args): + """ + This is python implementation of [1],[2], and [3]. + + Refs: + [1] Rainer Martin. + Noise power spectral density estimation based on optimal smoothing and minimum statistics. + IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. + [2] Rainer Martin. 
+ Bias compensation methods for minimum statistics noise power spectral density estimation + Signal Processing, 2006, 86, 1215-1229 + [3] Dirk Mauler and Rainer Martin + Noise power spectral density estimation on highly correlated data + Proc IWAENC, 2006 + + Copyright (C) Mike Brookes 2008 + Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ + + VOICEBOX is a MATLAB toolbox for speech processing. + Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html + """ + nargin = len(args) + + dmh=np.array([ + [1, 0, 0], + [2, 0.26, 0.15], + [5, 0.48, 0.48], + [8, 0.58, 0.78], + [10, 0.61, 0.98], + [15, 0.668, 1.55], + [20, 0.705, 2], + [30, 0.762, 2.3], + [40, 0.8, 2.52], + [60, 0.841, 3.1], + [80, 0.865, 3.38], + [120, 0.89, 4.15], + [140, 0.9, 4.35], + [160, 0.91, 4.25], + [180, 0.92, 3.9], + [220, 0.93, 4.1], + [260, 0.935, 4.7], + [300, 0.94, 5] + ],dtype=float) + + if nargin>=1: + d=args[0] + i=np.nonzero(d<=dmh[:,0]) + if len(i)==0: + i=np.shape(dmh)[0]-1 + j=i + else: + i=i[0][0] + j=i-1 + if d==dmh[i,0]: + m=dmh[i,1] + h=dmh[i,2] + else: + qj=np.sqrt(dmh[i-1,0]) # interpolate usignalng sqrt(d) + qi=np.sqrt(dmh[i,0]) + q=np.sqrt(d) + h=dmh[i,2]+(q-qi)*(dmh[j,2]-dmh[i,2])/(qj-qi) + m=dmh[i,1]+(qi*qj/q-qj)*(dmh[j,1]-dmh[i,1])/(qi-qj) + else: + d=dmh[:,0].copy() + m=dmh[:,1].copy() + h=dmh[:,2].copy() + + return m,h,d + + +def max_complex(a,b): + """ + This is python implementation of [1],[2], and [3]. + + Refs: + [1] Rainer Martin. + Noise power spectral density estimation based on optimal smoothing and minimum statistics. + IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. + [2] Rainer Martin. + Bias compensation methods for minimum statistics noise power spectral density estimation + Signal Processing, 2006, 86, 1215-1229 + [3] Dirk Mauler and Rainer Martin + Noise power spectral density estimation on highly correlated data + Proc IWAENC, 2006 + + Copyright (C) Mike Brookes 2008 + Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ + + VOICEBOX is a MATLAB toolbox for speech processing. + Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html + """ + if len(a)==1 and len(b)>1: + a=np.tile(a,np.shape(b)) + if len(b)==1 and len(a)>1: + b=np.tile(b,np.shape(a)) + + i=np.logical_or(np.iscomplex(a),np.iscomplex(b)) + + aa = a.copy() + bb = b.copy() + + if any(i): + aa[i]=np.absolute(aa[i]) + bb[i]=np.absolute(bb[i]) + if a.dtype == 'complex' or b.dtype== 'complex': + cc = np.array(np.zeros(np.shape(a)) ) + else: + cc = np.array(np.zeros(np.shape(a)),dtype=float) + + i=aa>bb + cc[i]=a[i] + cc[np.logical_not(i)] = b[np.logical_not(i)] + + return cc + +def min_complex(a,b): + """ + This is python implementation of [1],[2], and [3]. + + Refs: + [1] Rainer Martin. + Noise power spectral density estimation based on optimal smoothing and minimum statistics. + IEEE Trans. Speech and Audio Processing, 9(5):504-512, July 2001. + [2] Rainer Martin. + Bias compensation methods for minimum statistics noise power spectral density estimation + Signal Processing, 2006, 86, 1215-1229 + [3] Dirk Mauler and Rainer Martin + Noise power spectral density estimation on highly correlated data + Proc IWAENC, 2006 + + Copyright (C) Mike Brookes 2008 + Version: $Id: estnoisem.m 1718 2012-03-31 16:40:41Z dmb $ + + VOICEBOX is a MATLAB toolbox for speech processing. 
+ Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html + """ + if len(a)==1 and len(b)>1: + a=np.tile(a,np.shape(b)) + if len(b)==1 and len(a)>1: + b=np.tile(b,np.shape(a)) + + i=np.logical_or(np.iscomplex(a),np.iscomplex(b)) + + aa = a.copy() + bb = b.copy() + + if any(i): + aa[i]=np.absolute(aa[i]) + bb[i]=np.absolute(bb[i]) + + if a.dtype == 'complex' or b.dtype== 'complex': + cc = np.array(np.zeros(np.shape(a)) ) + else: + cc = np.array(np.zeros(np.shape(a)),dtype=float) + + i=aa self.num_res_blocks: + for extra_lay in block[self.num_res_blocks:]: + hs.append(extra_lay(hs[-1])) + + # Middle. + h = hs[-1] + for _, lay in enumerate(self.middle): + h = lay(h, temb) + + # Upsampling. + for block in self.upsampling: + # Residual blocks for this resolution. + for idx_block in range(self.num_res_blocks+1): + if isinstance(block[idx_block], list): + if cond is not None: + h = block[idx_block][0](tf.concat([h, hs.pop()], axis=-1), temb, cond) + else: + h = block[idx_block][0](tf.concat([h, hs.pop()], axis=-1), temb) + h = block[idx_block][1](h) + else: + if cond is not None: + h = block[idx_block](tf.concat([h, hs.pop()], axis=-1), temb, cond) + else: + h = block[idx_block](tf.concat([h, hs.pop()], axis=-1), temb) + # Upsample. + if len(block) > self.num_res_blocks+1: + for extra_lay in block[self.num_res_blocks+1:]: + h = extra_lay(h) + + # End. + for lay in self.end: + h = lay(h) + + h = tf.keras.activations.sigmoid(h) + h = tf.squeeze(h, axis=-1) + + return tf.multiply(inputs, h) + diff --git a/model/unet_utils.py b/model/unet_utils.py new file mode 100644 index 0000000..1621308 --- /dev/null +++ b/model/unet_utils.py @@ -0,0 +1,40 @@ +# nn.py +# Source: https://github.com/hojonathanho/diffusion/blob/master/ +# diffusion_tf/nn.py +# Tensorflow 2.4.0 +# Windows/MacOS/Linux +# Python 3.7 + + +import math +import tensorflow as tf + + +def default_init(scale): + return tf.initializers.variance_scaling( + scale=1e-10 if scale == 0 else scale, + mode="fan_avg", + distribution="uniform") + + +def meanflat(x): + return tf.math.reduce_mean(x, axis=list(range(1, len(x.shape)))) + + +def get_timestep_embedding(timesteps, embedding_dim): + # From fairseq. Build sinusoidal embeddings. This matches the + # implementation in tensor2tensor, but differs slightly from the + # description in Section 3.5 of "Attention Is All You Need". + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = tf.math.exp(tf.range(half_dim, dtype=tf.float32) * -emb) + # emb = tf.range(num_embeddings, dtype=tf.float32)[:, None] * emb[None, :] + emb = tf.cast(timesteps, dtype=tf.float32)[:, None] * emb[None, :] + emb = tf.concat([tf.math.sin(emb), tf.math.cos(emb)], axis=1) + if embedding_dim % 2 == 1: # zero pad. 
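+        # pad one zero column onto the last axis so the output width equals embedding_dim when it is odd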
+ # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1) + emb = tf.pad(emb, [[0, 0], [0, 1]]) + assert emb.shape == [timesteps.shape[0], embedding_dim] + return emb \ No newline at end of file diff --git a/model/vad.py b/model/vad.py new file mode 100644 index 0000000..b0c1956 --- /dev/null +++ b/model/vad.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue May 1 20:43:28 2018 +@author: eesungkim +""" + +import math +import numpy as np +from model.estnoise_ms import * + + +def VAD(signal, sr, nFFT=512, win_length=0.025, hop_length=0.01, theshold=0.7): + """Voice Activity Detector + Parameters + ---------- + signal : audio time series + sr : sampling rate of `signal` + nFFT : length of the FFT window + win_length : window size in sec + hop_length : hop size in sec + + Returns + ------- + probRatio : frame-based voice activity probability sequence + """ + signal=signal.astype('float') + + maxPosteriorSNR= 1000 + minPosteriorSNR= 0.0001 + + win_length_sample = round(win_length*sr) + hop_length_sample = round(hop_length*sr) + + # the variance of the speech; lambda_x(k) + _stft = stft(signal, n_fft=nFFT, win_length=win_length_sample, hop_length=hop_length_sample) + pSpectrum = np.abs(_stft) ** 2 + + # estimate the variance of the noise using minimum statistics noise PSD estimation ; lambda_d(k). + estNoise = estnoisem(pSpectrum,hop_length) + estNoise = estNoise + + aPosterioriSNR=pSpectrum/estNoise + aPosterioriSNR=aPosterioriSNR + aPosterioriSNR[aPosterioriSNR > maxPosteriorSNR] = maxPosteriorSNR + aPosterioriSNR[aPosterioriSNR < minPosteriorSNR] = minPosteriorSNR + + a01=hop_length/0.05 # a01=P(signallence->speech) hop_length/mean signallence length (50 ms) + a00=1-a01 # a00=P(signallence->signallence) + a10=hop_length/0.1 # a10=P(speech->signallence) hop/mean talkspurt length (100 ms) + a11=1-a10 # a11=P(speech->speech) + + b01=a01/a00 + b10=a11-a10*a01/a00 + + smoothFactorDD=0.99 + previousGainedaPosSNR=1 + (nFrames,nFFT2) = pSpectrum.shape + probRatio=np.zeros((nFrames,1)) + logGamma_frame=0 + for i in range(nFrames): + aPosterioriSNR_frame = aPosterioriSNR[i,:] + + #operator [2](52) + oper=aPosterioriSNR_frame-1 + oper[oper < 0] = 0 + smoothed_a_priori_SNR = smoothFactorDD * previousGainedaPosSNR + (1-smoothFactorDD) * oper + + #V for MMSE estimate ([2](8)) + V=0.1*smoothed_a_priori_SNR*aPosterioriSNR_frame/(1+smoothed_a_priori_SNR) + + #geometric mean of log likelihood ratios for individual frequency band [1](4) + logLRforFreqBins=2*V-np.log(smoothed_a_priori_SNR+1) + # logLRforFreqBins=np.exp(smoothed_a_priori_SNR*aPosterioriSNR_frame/(1+smoothed_a_priori_SNR))/(1+smoothed_a_priori_SNR) + gMeanLogLRT=np.mean(logLRforFreqBins) + logGamma_frame=np.log(a10/a01) + gMeanLogLRT + np.log(b01+b10/( a10+a00*np.exp(-logGamma_frame) ) ) + probRatio[i]=1/(1+np.exp(-logGamma_frame)) + + #Calculate Gain function which results from the MMSE [2](7). 
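+        # MMSE short-time spectral amplitude gain (cf. Ephraim & Malah, 1984, Eq. 7);
+        # bessel(0, .) and bessel(1, .) from estnoise_ms.py are the modified Bessel functions of the first kind I0 and I1.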
+ gain = (math.gamma(1.5) * np.sqrt(V)) / aPosterioriSNR_frame * np.exp(-1 * V / 2) * ((1 + V) * bessel(0, V / 2) + V * bessel(1, V / 2)) + + previousGainedaPosSNR = (gain**2) * aPosterioriSNR_frame + probRatio[probRatio>theshold]=1 + probRatio[probRatio=2.1.0 +tqdm==4.48.2 \ No newline at end of file diff --git a/separate.py b/separate.py new file mode 100644 index 0000000..69b55f2 --- /dev/null +++ b/separate.py @@ -0,0 +1,115 @@ +import os +import tqdm +import json +import math +import argparse + +import numpy as np +import soundfile as sf +import tensorflow as tf + +from scipy.signal import get_window + +from config import Config as UnetConfig +from model import DiffWave +from model.vad import VAD +from model.clustering import get_mask +from utils.signal_processing import ( + compute_stft, + compute_signal_from_stft, + next_power_of_2 +) + + +def main(args): + + # Activate CUDA if GPU id is given + if args.gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + model_path = os.path.join(".", "ckpt", args.model_name, args.model_name+"_BEST-SDR.ckpt-1") + with open(os.path.join(".", "ckpt", args.model_name + ".json")) as f: + unet_config = UnetConfig.load(json.load(f)) + + diffwave = DiffWave(unet_config.model) + diffwave.restore(model_path).expect_partial() + mixture = tf.io.read_file(args.input_signal) + mixture, _ = tf.audio.decode_wav(mixture, desired_channels=1) + mixture = tf.squeeze(mixture, axis=-1) / tf.reduce_max(mixture) + + TRIMS = args.batch_size + output_voc = np.zeros(mixture.shape) + hopsized_batch = int((TRIMS*22050) / 2) + runs = math.floor(mixture.shape[0] / hopsized_batch) + trim_low = 0 + for trim in tqdm.tqdm(np.arange((runs*2)-1)): + trim_high = int(trim_low + (hopsized_batch*2)) + + # Get input mixture spectrogram + mix_trim = mixture[trim_low:trim_high] + mix_mag, mix_phase = compute_stft(mix_trim[None], unet_config) + new_len = next_power_of_2(mix_mag.shape[1]) + mix_mag_trim = mix_mag[:, :new_len, :] + mix_phase_trim = mix_phase[:, :new_len, :] + + # Get and stack cold diffusion steps + diff_feat = diffwave(mix_mag_trim, mode="train") + diff_feat = tf.transpose(diff_feat, [1, 0, 2, 3]) + diff_feat_t = tf.squeeze(tf.reshape(diff_feat, [1, 8, diff_feat.shape[-2]*diff_feat.shape[-1]]), axis=0).numpy() + + # Normalize features, all energy curves having same range + normalized_feat = [] + for j in np.arange(diff_feat_t.shape[1]): + normalized_curve = diff_feat_t[:, j] / np.max(np.abs(diff_feat_t[:, j])) + normalized_feat.append(normalized_curve) + normalized_feat = np.array(normalized_feat, dtype=np.float32) + + # Compute mask using unsupervised clustering and reshape to magnitude spec shape + mask = get_mask(normalized_feat, args.clusters, args.scheduler) + mask = tf.reshape(mask, mix_mag_trim.shape) + + # Getting last step of computed features and applying mask + diff_feat_t = tf.reshape(diff_feat_t[-1, :], mix_mag_trim.shape) + output_signal = tf.math.multiply(diff_feat_t, mask) + + # Silence unvoiced regions + output_signal = compute_signal_from_stft(output_signal, mix_phase_trim, unet_config) + pred_audio = tf.squeeze(output_signal, axis=0).numpy() + vad = VAD(pred_audio, sr=22050, nFFT=512, win_length=0.025, hop_length=0.01, theshold=0.99) + if np.sum(vad) / len(vad) < 0.25: + pred_audio = np.zeros(pred_audio.shape) + + # Get boundary + boundary = None + boundary = "start" if trim == 0 else None + boundary = "end" if trim == runs-2 else None + + placehold_voc = np.zeros(output_voc.shape) + 
placehold_voc[trim_low:trim_low+pred_audio.shape[0]] = pred_audio * get_window(pred_audio, boundary=boundary) + output_voc += placehold_voc + trim_low += pred_audio.shape[0] // 2 + + output_voc = output_voc * (np.max(np.abs(mixture.numpy())) / np.max(np.abs(output_voc))) + + # Building intuitive filename with model config + filefolder = os.path.join(args.input_signal.split("/")[:-1]) + filename = args.input_signal.split("/")[-1].split(".")[:-1] + filename = filename[0] if len(filename) == 1 else ".".join(filename) + filename = filename + "_" + str(args.clusters) + "_" + str(args.scheduler) + "pred_voc" + sf.write( + os.path.join(filefolder, filename + ".wav"), + output_voc, + 22050) # Writing to file + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model-name', default='saraga-8') + parser.add_argument('--input_signal', default=None, type=str) + parser.add_argument('--batch-size', default=6) + parser.add_argument('--clusters', default=4, type=int) + parser.add_argument('--scheduler', default=3., type=float) + parser.add_argument('--gpu', default=None) + args = parser.parse_args() + main(args) diff --git a/train.py b/train.py new file mode 100644 index 0000000..030036f --- /dev/null +++ b/train.py @@ -0,0 +1,276 @@ +import os +import argparse +import json +import tqdm + +import numpy as np +import tensorflow as tf + +from config import Config as UnetConfig +from dataset import SARAGA +from model import DiffWave + +from utils.separation_eval import GlobalSDR +from utils.signal_processing import check_shape_3d + +import warnings +warnings.filterwarnings('ignore') + +epsilon = 1e-6 + +class Trainer: + """WaveGrad trainer. + """ + def __init__(self, model, saraga, config, data_dir): + """Initializer. + Args: + model: DiffWave, diffwave model. + saraga: Saraga, saraga dataset + which provides already batched and normalized speech dataset. + config: Config, unified configurations. + """ + self.model = model + self.saraga = saraga + self.config = config + self.data_dir = data_dir + + self.split = config.train.split // config.data.batch + self.trainset = self.saraga.dataset().take(self.split) \ + .shuffle(config.train.bufsiz) \ + .prefetch(tf.data.experimental.AUTOTUNE) + self.testset = self.saraga.test_dataset() \ + .prefetch(tf.data.experimental.AUTOTUNE) + + self.optim = tf.keras.optimizers.Adam( + config.train.lr(), + config.train.beta1, + config.train.beta2, + config.train.eps) + + self.eval_intval = config.train.eval_intval // config.data.batch + self.ckpt_intval = config.train.ckpt_intval // config.data.batch + + self.train_log = tf.summary.create_file_writer( + os.path.join(config.train.log, config.train.name, 'train')) + self.test_log = tf.summary.create_file_writer( + os.path.join(config.train.log, config.train.name, 'test')) + + self.ckpt_path = os.path.join( + config.train.ckpt, config.train.name, config.train.name) + + self.alpha_bar = np.linspace(1, 0, config.model.iter + 1) + + @staticmethod + def tf_log10(x): + numerator = tf.math.log(x) + denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype)) + return numerator / denominator + + def compute_loss(self, mixture, vocals): #, accomp): + """Compute loss for noise estimation. + Args: + signal: tf.Tensor, [B, T], raw audio signal mixture. + signal: tf.Tensor, [B, T], raw audio signal vocals. + Returns: + loss: tf.Tensor, [], L1-loss between noise and estimation. 
+ """ + bsize = tf.shape(vocals)[0] + # [B] + timesteps = tf.random.uniform( + [bsize], 1, self.config.model.iter + 1, dtype=tf.int32) + + # [B] + noise_level = tf.gather(self.alpha_bar, timesteps) + noise_level_next = tf.gather(self.alpha_bar, timesteps - 1) + + # [B, T], [B, T] + noised = self.model.diffusion(mixture, vocals, noise_level) + noised_next = self.model.diffusion(mixture, vocals, noise_level_next) + # [B, T] + est = self.model.pred_noise(noised, timesteps) + # [] + l1_loss = tf.reduce_mean(tf.abs(est - noised_next)) + return l1_loss + + + def train(self, step=0): + """Train wavegrad. + Args: + step: int, starting step. + ir_unit: int, log ir units. + """ + count = 0 + best_SDR = 0 + best_step = 0 + less_loss = 1000 + less_train_loss = 1000 + # Start training + for _ in tqdm.trange(step // self.split, self.config.train.epoch): + train_loss = [] + with tqdm.tqdm(total=self.split, leave=False) as pbar: + for mixture, vocal in self.trainset: + with tf.GradientTape() as tape: + tape.watch(self.model.trainable_variables) + loss = self.compute_loss(mixture, vocal) + train_loss.append(loss) + + grad = tape.gradient(loss, self.model.trainable_variables) + self.optim.apply_gradients( + zip(grad, self.model.trainable_variables)) + + norm = tf.reduce_mean([tf.norm(g) for g in grad]) + del grad + + step += 1 + pbar.update() + pbar.set_postfix( + {'loss': loss.numpy().item(), + 'step': step, + 'grad': norm.numpy().item()}) + + if step % self.ckpt_intval == 0: + self.model.write( + '{}.ckpt'.format(self.ckpt_path), + self.optim) + + train_loss = sum(train_loss) / len(train_loss) + print('\nTrain loss:', str(round(train_loss.numpy(), 5))) + loss = [] + for mixture, vocal in self.testset: + actual_loss = self.compute_loss(mixture, vocal) + loss.append(actual_loss.numpy().item()) + + del mixture, vocal + loss = sum(loss) / len(loss) + print('Eval loss:', str(round(loss, 5))) + if loss <= less_loss: + if train_loss <= less_train_loss: + print('Saving best new model given loss values!') + self.model.write('{}_BEST-LOSS.ckpt'.format(self.ckpt_path),self.optim) + less_loss = loss + less_train_loss = train_loss + + with self.test_log.as_default(): + if count%1 == 0: + best_SDR, best_step = self.eval(best_SDR, best_step, step) + count += 1 + + + def eval(self, best_SDR, best_step, step): + """Generate evaluation purpose audio. + Returns: + speech: np.ndarray, [T], ground truth. + pred: np.ndarray, [T], predicted. + ir: List[np.ndarray], config.model.iter x [T], + intermediate representations. 
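+        Note:
+            the method reports the median vocal SDR over (up to) 300
+            validation examples, writes a new BEST-SDR checkpoint whenever
+            that median improves, and returns the tracked (best_SDR, best_step).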
+ """ + # [T] + voc_sdr = [] + for mixture, vocals in tqdm.tqdm(saraga.validation().take(300)): + if np.max(tf.squeeze(mixture, axis=0).numpy())>0: + if np.max(tf.squeeze(vocals, axis=0).numpy())>0: + mix_mag, _ = self.compute_stft(mixture) + _, voc_phase = self.compute_stft(vocals) + + pred = self.model(mix_mag) + pred = self.compute_signal_from_stft(pred, voc_phase) + mixture = mixture[:, :pred.shape[1]] + vocals = vocals[:, :pred.shape[1]] + pred = tf.transpose(pred, [1, 0]).numpy() + vocals = tf.transpose(vocals, [1, 0]).numpy() + + ref = np.array([vocals]) + est = np.array([pred]) + + scores = GlobalSDR(ref, est) + voc_sdr.append(scores[0]) + + print('Median SDR:', np.median(voc_sdr)) + print('Best model:', best_SDR) + if np.median(voc_sdr) > best_SDR: + print('Saving best new model with SDR:', np.median(voc_sdr)) + self.model.write('{}_BEST-SDR.ckpt'.format(self.ckpt_path),self.optim) + best_SDR = np.median(voc_sdr) + return best_SDR, best_step + + def compute_stft(self, signal): + signal_stft = check_shape_3d( + check_shape_3d( + tf.signal.stft( + signal, + frame_length=self.config.model.win, + frame_step=self.config.model.hop, + fft_length=self.config.model.win, + window_fn=tf.signal.hann_window), 1), 2) + mag = tf.abs(signal_stft) + phase = tf.math.angle(signal_stft) + return mag, phase + + def compute_signal_from_stft(self, spec, phase): + polar_spec = tf.complex(tf.multiply(spec, tf.math.cos(phase)), tf.zeros(spec.shape)) + \ + tf.multiply(tf.complex(spec, tf.zeros(spec.shape)), tf.complex(tf.zeros(phase.shape), tf.math.sin(phase))) + return tf.signal.inverse_stft( + polar_spec, + frame_length=self.config.model.win, + frame_step=self.config.model.hop, + window_fn=tf.signal.inverse_stft_window_fn( + self.config.model.hop, + forward_window_fn=tf.signal.hann_window)) + + @staticmethod + def load_audio(paths): + mixture = tf.io.read_file(paths[0]) + vocals = tf.io.read_file(paths[1]) + mixture_audio, _ = tf.audio.decode_wav(mixture, desired_channels=1) + vocal_audio, _ = tf.audio.decode_wav(vocals, desired_channels=1) + return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default=None) + parser.add_argument('--load-step', default=0, type=int) + parser.add_argument('--data-dir', default=None) + parser.add_argument('--gpu', default=None) + args = parser.parse_args() + + # Activate CUDA if GPU id is given + if args.gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + config = UnetConfig() + + if args.config is not None: + print('[*] load config: ' + args.config) + with open(args.config) as f: + config = UnetConfig.load(json.load(f)) + + log_path = os.path.join(config.train.log, config.train.name) + if not os.path.exists(log_path): + os.makedirs(log_path) + + ckpt_path = os.path.join(config.train.ckpt, config.train.name) + if not os.path.exists(ckpt_path): + os.makedirs(ckpt_path) + + sounds_path = os.path.join(config.train.sounds, config.train.name) + if not os.path.exists(sounds_path): + os.makedirs(sounds_path) + + saraga = SARAGA(config.data, data_dir=args.data_dir) + diffwave = DiffWave(config.model) + trainer = Trainer(diffwave, saraga, config, data_dir=args.data_dir) + + if args.load_step > 0: + super_path = os.path.join(config.train.ckpt, config.train.name) + ckpt_path = os.path.join(super_path, '{}.ckpt-1'.format(config.train.name)) + print('[*] load checkpoint: ' + ckpt_path) + 
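+        # restore model weights (and optimizer state, if stored in the
+        # checkpoint) so training can resume from --load-step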
trainer.model.restore(ckpt_path, trainer.optim) + print("Loaded!") + + with open(os.path.join(config.train.ckpt, config.train.name + '.json'), 'w') as f: + json.dump(config.dump(), f) + + trainer.train(args.load_step) diff --git a/utils/noam_schedule.py b/utils/noam_schedule.py new file mode 100644 index 0000000..4343d24 --- /dev/null +++ b/utils/noam_schedule.py @@ -0,0 +1,32 @@ +import tensorflow as tf + + +class NoamScheduler(tf.keras.optimizers.schedules.LearningRateSchedule): + """Noam learning rate scheduler from Vaswani et al., 2017. + """ + def __init__(self, learning_rate, warmup_steps, channels): + """Initializer. + Args: + learning_rate: float, initial learning rate. + warmup_steps: int, warmup steps. + channels: int, base hidden size of the model. + """ + super(NoamScheduler, self).__init__() + self.learning_rate = learning_rate + self.warmup_steps = warmup_steps + self.channels = channels + + def __call__(self, step): + """Compute learning rate. + """ + return self.learning_rate * self.channels ** -0.5 * \ + tf.minimum(step ** -0.5, step * self.warmup_steps ** -1.5) + + def get_config(self): + """Serialize configurations. + """ + return { + 'learning_rate': self.learning_rate, + 'warmup_steps': self.warmup_steps, + 'channels': self.channels, + } diff --git a/utils/phase_vocoder.py b/utils/phase_vocoder.py new file mode 100644 index 0000000..87101e5 --- /dev/null +++ b/utils/phase_vocoder.py @@ -0,0 +1,59 @@ +import numpy as np +import tensorflow as tf + + +def phase_vocoder(D, hop_len=256, rate=0.8): + """Phase vocoder. Given an STFT matrix D, speed up by a factor of `rate`. + Based on implementation provided by: + https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#phase_vocoder + :param D: tf.complex64([num_frames, num_bins]): the STFT tensor + :param hop_len: float: the hop length param of the STFT + :param rate: float > 0: the speed-up factor + :return: D_stretched: tf.complex64([num_frames, num_bins]): the stretched STFT tensor + """ + # get shape + sh = tf.shape(D, name="STFT_shape") + frames = sh[0] + fbins = sh[1] + + # time steps range + t = tf.range(0.0, tf.cast(frames, tf.float32), rate, dtype=tf.float32, name="time_steps") + + # Expected phase advance in each bin + dphi = tf.linspace(0.0, np.pi * hop_len, fbins, name="dphi_expected_phase_advance") + phase_acc = tf.math.angle(D[0, :], name="phase_acc_init") + + # Pad 0 columns to simplify boundary logic + D = tf.pad(D, [(0, 2), (0, 0)], mode='CONSTANT', name="padded_STFT") + + # def fn(previous_output, current_input): + def _pvoc_mag_and_cum_phase(previous_output, current_input): + # unpack prev phase + _, prev = previous_output + + # grab the two current columns of the STFT + i = tf.cast((tf.floor(current_input) + [0, 1]), tf.int32) + bcols = tf.gather_nd(D, [[i[0]], [i[1]]]) + + # Weighting for linear magnitude interpolation + t_dif = current_input - tf.floor(current_input) + bmag = (1 - t_dif) * tf.abs(bcols[0, :]) + t_dif * (tf.abs(bcols[1, :])) + + # Compute phase advance + dp = tf.math.angle(bcols[1, :]) - tf.math.angle(bcols[0, :]) - dphi + dp = dp - 2 * np.pi * tf.round(dp / (2.0 * np.pi)) + + # return linear mag, accumulated phase + return bmag, tf.squeeze(prev + dp + dphi) + + # initializer of zeros of correct shape for mag, and phase_acc for phase + initializer = (tf.zeros(fbins, tf.float32), phase_acc) + mag, phase = tf.scan(_pvoc_mag_and_cum_phase, t, initializer=initializer, + parallel_iterations=10, back_prop=False, + name="pvoc_cum_phase") + + # add the original phase_acc in + 
phase2 = tf.concat([tf.expand_dims(phase_acc, 0), phase], 0)[:-1, :] + D_stretched = tf.cast(mag, tf.complex64) * tf.exp(1.j * tf.cast(phase2, tf.complex64), name="stretched_STFT") + + return D_stretched \ No newline at end of file diff --git a/utils/separation_eval.py b/utils/separation_eval.py new file mode 100644 index 0000000..3d0b6f2 --- /dev/null +++ b/utils/separation_eval.py @@ -0,0 +1,10 @@ +import numpy as np + +def GlobalSDR(references, separations): + """ Global SDR """ + delta = 1e-7 # avoid numerical errors + num = np.sum(np.square(references), axis=(1, 2)) + den = np.sum(np.square(references - separations), axis=(1, 2)) + num += delta + den += delta + return 10 * np.log10(num / den) \ No newline at end of file diff --git a/utils/signal_processing.py b/utils/signal_processing.py new file mode 100644 index 0000000..bf4fc05 --- /dev/null +++ b/utils/signal_processing.py @@ -0,0 +1,74 @@ +import numpy as np +import tensorflow as tf + + +def get_window(signal, boundary=None): + window_out = np.ones(signal.shape) + midpoint = window_out.shape[0] // 2 + if boundary == "start": + window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) + elif boundary == "end": + window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) + else: + window_out[:midpoint] = np.linspace(0, 1, window_out.shape[0]-midpoint) + window_out[midpoint:] = np.linspace(1, 0, window_out.shape[0]-midpoint) + return window_out + + +def compute_stft(signal, unet_config): + signal_stft = check_shape_3d( + check_shape_3d( + tf.signal.stft( + signal, + frame_length=unet_config.model.win, + frame_step=unet_config.model.hop, + fft_length=unet_config.model.win, + window_fn=tf.signal.hann_window), 1), 2) + mag = tf.abs(signal_stft) + phase = tf.math.angle(signal_stft) + return mag, phase + + +def compute_signal_from_stft(spec, phase, config): + polar_spec = tf.complex(tf.multiply(spec, tf.math.cos(phase)), tf.zeros(spec.shape)) + \ + tf.multiply(tf.complex(spec, tf.zeros(spec.shape)), tf.complex(tf.zeros(phase.shape), tf.math.sin(phase))) + return tf.signal.inverse_stft( + polar_spec, + frame_length=config.model.win, + frame_step=config.model.hop, + window_fn=tf.signal.inverse_stft_window_fn( + config.model.hop, + forward_window_fn=tf.signal.hann_window)) + + +def log2(x, base): + return int(np.log(x) / np.log(base)) + + +def next_power_of_2(n): + # decrement `n` (to handle the case when `n` itself is a power of 2) + n = n - 1 + # calculate the position of the last set bit of `n` + lg = log2(n, 2) + # next power of two will have a bit set at position `lg+1`. + return 1 << lg #+ 1 + + +def check_shape_3d(data, dim): + n = data.shape[dim] + if n % 2 != 0: + n = data.shape[dim] - 1 + if dim==0: + return data[:n, :, :] + if dim==1: + return data[:, :n, :] + if dim==2: + return data[:, :, :n] + + +def load_audio(paths): + mixture = tf.io.read_file(paths[0]) + vocals = tf.io.read_file(paths[1]) + mixture_audio, _ = tf.audio.decode_wav(mixture, desired_channels=1) + vocal_audio, _ = tf.audio.decode_wav(vocals, desired_channels=1) + return tf.squeeze(mixture_audio, axis=-1), tf.squeeze(vocal_audio, axis=-1)
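
The helpers above are small enough to sanity-check in isolation. The following is a minimal sketch (not part of the patch), assuming the repository root is on the Python path so that utils/signal_processing.py and utils/separation_eval.py from this diff are importable:

import numpy as np

from utils.separation_eval import GlobalSDR
from utils.signal_processing import get_window

# Cross-fade check: separate.py hops by half a chunk, so the descending ramp
# of one window plus the ascending ramp of the next should sum to one.
w = get_window(np.zeros(8))
print(w[4:] + w[:4])            # -> [1. 1. 1. 1.]

# GlobalSDR takes stacked arrays shaped [n_sources, n_samples, n_channels]
# and returns one dB value per source.
rng = np.random.default_rng(0)
clean = rng.standard_normal((1, 22050, 1))
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
print(GlobalSDR(clean, noisy))  # roughly 20 dB for 10% additive noise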