How to Train Residual LFQ #103

Open
ZetangForward opened this issue Jan 19, 2024 · 0 comments

Comments

ZetangForward commented Jan 19, 2024

Hi, I want to train a Residual LFQ model for audio, and this is my core code:

import torch

def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    # default: average over all elements when no padding mask is given
    mask_sum = x_target.numel()
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(-1).expand_as(x_target)
        x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target)).to(x_pred.device)
        x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred)).to(x_pred.device)
        mask_sum = padding_mask.sum()

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        # only consider the residual of the valid (non-padded) part
        masked_residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        values, _ = torch.topk(masked_residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        raise ValueError(f"Unknown loss_fn {loss_fn}")

    return loss
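
For clarity, here is a small usage sketch of _loss_fn under the shapes I assume: x_target/x_pred are (batch, time, channels) and padding_mask is (batch, time) with True marking valid frames (the tensor sizes below are placeholders):

# hypothetical shapes for illustration only
x_target = torch.randn(4, 100, 128)
x_pred = torch.randn(4, 100, 128)
padding_mask = torch.ones(4, 100, dtype=torch.bool)
padding_mask[:, 80:] = False  # last 20 frames are padding

# cfg is only needed for the 'linf' branch (cfg.linf_k), so None is fine here
loss = _loss_fn('l2', x_target, x_pred, cfg=None, padding_mask=padding_mask)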


def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss

from vector_quantize_pytorch import ResidualLFQ

model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers
)

I use reconstruction_loss and commit_loss to jointly update the ResidualLFQ model.
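
For context, a minimal sketch of the joint update I have in mind (the optimizer choice, learning rate, dataloader name, and using batch['audio'] as the reconstruction target are placeholders, not my exact setup):

# placeholder optimizer and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in train_dataloader:  # placeholder dataloader
    quantized, indices, commit_loss = model(batch['audio'], batch['padding_mask'])
    quantized_out = model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['audio'], quantized_out, cfg, batch['padding_mask'])

    # both losses contribute gradients in the same backward pass
    loss = reconstruction_loss + commit_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()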

I am wondering about two things:

  1. Is the reconstruction loss necessary?
  2. Sometimes the commitment loss is negative, e.g. -0.02. Is this normal? Since I add commit_loss and reconstruction_loss together, it seems odd that one loss is positive while the other is negative ...

I hope to get some suggestions @kashif @lucidrains. Thank you!
