Hi, I want to train a Residual LFQ model for audio, and this is my core code:
```python
import torch


def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    if padding_mask is not None:
        # True marks valid (non-padded) positions; broadcast the mask to the feature dim
        padding_mask = padding_mask.unsqueeze(-1).expand_as(x_target)
        x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target)).to(x_pred.device)
        x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred)).to(x_pred.device)
        mask_sum = padding_mask.sum()
    else:
        mask_sum = x_target.numel()  # no mask: average over every element

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        # only consider the residual of the non-padded (valid) positions
        masked_residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        values, _ = torch.topk(masked_residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        assert False, f"Unknown loss_fn {loss_fn}"
    return loss
```
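For reference, here is a quick sanity check of the masked `l2` branch (a minimal sketch; the shapes and the `linf_k` config field are placeholders consistent with the snippet above):

```python
import torch
from types import SimpleNamespace

# hypothetical shapes: batch of 2 sequences, length 4, feature dim 3
x_target = torch.randn(2, 4, 3)
x_pred = torch.randn(2, 4, 3)

# True marks valid frames, False marks padding, matching how the mask is used above
padding_mask = torch.tensor([[True, True, True, False],
                             [True, True, False, False]])

cfg = SimpleNamespace(linf_k=2)  # only read by the 'linf' branch
loss = _loss_fn('l2', x_target, x_pred, cfg, padding_mask)
print(loss)  # scalar: summed squared error over valid positions / number of valid elements
```

The training step and model setup: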
```python
def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss
```
```python
from vector_quantize_pytorch import ResidualLFQ

model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers
)
```
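For completeness, the forward pass can be exercised standalone like this (a minimal sketch assuming the `vector-quantize-pytorch` `ResidualLFQ` API; the concrete hyperparameter values are made up, not my real config):

```python
import torch
from vector_quantize_pytorch import ResidualLFQ

model = ResidualLFQ(
    dim = 256,              # placeholder values for illustration
    codebook_size = 2 ** 8,
    num_quantizers = 4
)

x = torch.randn(1, 1024, 256)  # (batch, seq_len, dim)

# returns the quantized output, the code indices (one per quantizer step),
# and the auxiliary commit loss that I add to the reconstruction loss
quantized, indices, commit_loss = model(x)

reconstructed = model.get_output_from_indices(indices)
```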
I use `reconstruction_loss` and `commit_loss` together to jointly update the ResidualLFQ model. I wonder about two things:

1. Is the reconstruction loss necessary?
2. Sometimes the commitment loss is negative, e.g., -0.02. Is this normal? Since I add `commit_loss` and `reconstruction_loss` together, it seems odd that one loss is positive while the other is negative ...
I hope to get some suggestions @kashif @lucidrains. Thank you!