drum_diffusion.yaml
# @package _global_
# Unconditional Audio Waveform Diffusion
# To execute this experiment on a single GPU, run:
# python train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path=<path/to/your/train/data>
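# Since this is a Hydra config (`@package _global_`), any other value in this file can be overridden
# from the command line in the same way. Illustrative example (adjust keys/values to your setup):
# python train.py exp=drum_diffusion trainer.gpus=1 batch_size=2 accumulate_grad_batches=16 datamodule.dataset.path=<path/to/your/train/data>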
module: main.diffusion_module
batch_size: 1 # mini-batch size (increase to speed up at the cost of memory)
accumulate_grad_batches: 32 # use to increase batch size on single GPU -> effective batch size = (batch_size * accumulate_grad_batches)
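# e.g. with the values above, the effective batch size is 1 * 32 = 32 samples per optimizer step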
num_workers: 8 # num workers for data loading
sampling_rate: 44100 # sampling rate (44.1kHz is the music industry standard)
length: 32768 # Length of audio in samples (32768 samples @ 44.1kHz ~ 0.75 seconds)
channels: 2 # stereo audio
val_log_every_n_steps: 1000 # Logging interval (Validation and audio generation every n steps)
# ckpt_every_n_steps: 4000 # Use if multiple checkpoints wanted

model:
  _target_: ${module}.Model # pl model wrapper
  lr: 1e-4 # optimizer learning rate
  lr_beta1: 0.95 # beta1 param for Adam optimizer
  lr_beta2: 0.999 # beta2 param for Adam optimizer
  lr_eps: 1e-6 # epsilon for optimizer (to avoid div by 0)
  lr_weight_decay: 1e-3 # weight decay regularization param
  ema_beta: 0.995 # EMA model (exponential-moving-average) beta
  ema_power: 0.7 # EMA model decay power (warmup schedule exponent)

  model:
    _target_: audio_diffusion_pytorch.DiffusionModel # Waveform diffusion model
    net_t:
      _target_: ${module}.UNetT # The model type used for diffusion (U-Net V0 in this case)
    in_channels: 2 # U-Net: number of input/output (audio) channels
    channels: [32, 32, 64, 64, 128, 128, 256, 256] # U-Net: channels at each layer
    factors: [1, 2, 2, 2, 2, 2, 2, 2] # U-Net: downsampling and upsampling factors at each layer
    items: [2, 2, 2, 2, 2, 2, 4, 4] # U-Net: number of repeating items at each layer
    attentions: [0, 0, 0, 0, 0, 1, 1, 1] # U-Net: attention enabled/disabled at each layer
    attention_heads: 8 # U-Net: number of attention heads per attention item
    attention_features: 64 # U-Net: number of attention features per attention item
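    # Illustrative note (assuming each entry in `factors` is the temporal downsampling at that layer):
    # the total downsampling is 1*2*2*2*2*2*2*2 = 128, so a 32768-sample input is reduced to
    # 32768 / 128 = 256 time frames at the deepest U-Net layer.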

# To use separate train and validation datasets, the datamodule must be reconfigured
datamodule:
  _target_: main.diffusion_module.Datamodule
  dataset:
    _target_: audio_data_pytorch.WAVDataset
    path: ./data/wav_dataset # can be overridden when calling train.py
    recursive: True
    sample_rate: ${sampling_rate}
    transforms:
      _target_: audio_data_pytorch.AllTransform
      crop_size: ${length} # One-shots, so no random crop
      stereo: True
      source_rate: ${sampling_rate}
      target_rate: ${sampling_rate}
      loudness: -20 # normalize loudness
  val_split: 0.1 # fraction of data held out for validation
  batch_size: ${batch_size}
  num_workers: ${num_workers}
  pin_memory: True
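  # Example layout (hypothetical paths): with recursive: True, all .wav files under `path` are used, e.g.
  #   ./data/wav_dataset/kicks/kick_01.wav
  #   ./data/wav_dataset/snares/snare_01.wav
  # The path can also be overridden at launch time: datamodule.dataset.path=<path/to/your/train/data>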

callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
    # _target_: pytorch_lightning.callbacks.TQDMProgressBar # use if RichProgressBar creates issues

  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "valid_loss" # name of the logged metric which determines when the model is improving
    save_top_k: 1 # save k best models (determined by above metric)
    save_last: True # additionally always save model from last epoch
    mode: "min" # can be "max" or "min"
    verbose: False
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    filename: '{epoch:02d}-{valid_loss:.3f}'
    # every_n_train_steps: ${ckpt_every_n_steps} # Use if multiple checkpoints wanted
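    # With Lightning's default auto_insert_metric_name behavior, the template above is expected to
    # yield checkpoint files such as `epoch=05-valid_loss=0.123.ckpt` (illustrative values).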

  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
    max_depth: 2

  audio_samples_logger:
    _target_: main.diffusion_module.SampleLogger
    num_items: 4 # number of separate samples to be generated
    channels: ${channels} # number of audio channels
    sampling_rate: ${sampling_rate} # audio sampling rate
    length: ${length} # length of generated sample
    sampling_steps: [50] # number of steps per sample
    use_ema_model: True # Use EMA for logger inference
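    # Note: `sampling_steps` is a list; adding more entries (e.g. [10, 50]) is assumed to log audio
    # generated with each step count, allowing a quality/speed comparison.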

loggers:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    project: ${oc.env:WANDB_PROJECT} # defined in env var
    entity: ${oc.env:WANDB_ENTITY} # defined in env var
    name: unconditional_diffusion # name of run
    # offline: False # set True to store all logs only locally
    job_type: "train"
    group: "" # Set a group name if desired
    save_dir: ${logs_dir}
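    # WANDB_PROJECT and WANDB_ENTITY must be set in the environment before launching, e.g.:
    #   export WANDB_PROJECT=<your_wandb_project>
    #   export WANDB_ENTITY=<your_wandb_username_or_team>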

trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 0 # Set `1` to train on one GPU, `0` to train on CPU only, and `-1` to train on all GPUs (default `0`)
  precision: 16 # Precision used for tensors (`32` offers higher precision, but `16` is used to save memory)
  min_epochs: 0 # minimum number of epochs
  max_epochs: -1 # max number of epochs (-1 = infinite run)
  enable_model_summary: False
  log_every_n_steps: 1 # Logs training metrics every n steps
  # limit_val_batches: 10 # Use to limit the number of validation batches per check (e.g. 10 runs only 10 validation batches)
  check_val_every_n_epoch: null
  val_check_interval: ${val_log_every_n_steps} # Validation interval (check valid set and generate audio every n steps)
  accumulate_grad_batches: ${accumulate_grad_batches} # use to increase effective batch size on a single GPU