Add loss: Noise Contrastive Estimation #65

Open · wants to merge 10 commits into base: main
25 changes: 19 additions & 6 deletions i6_models/losses/nce.py
@@ -6,7 +6,6 @@
from torch import nn
from torch.nn import functional as F
from typing import Optional
-import math


class NoiseContrastiveEstimationLossV1(nn.Module):
@@ -17,31 +16,45 @@ class NoiseContrastiveEstimationLossV1(nn.Module):
    def __init__(
        self,
        num_samples: int,
        *,
        model: nn.Module,
        noise_distribution_sampler: nn.Module,
        log_norm_term: Optional[float] = None,
        reduction: str = "none",
        device: Optional[str] = None,
    ) -> None:
        """
        Noise contrastive estimation (NCE) loss implementation.
        Used to estimate the softmax, typically for very large softmax sizes, for example a word-level LM.

        :param num_samples: number of samples for the estimation, normally a value between 1000 and 4000.
            2000 is a good starting point.
        :param model: model on which the NCE loss is to be applied.
        :param noise_distribution_sampler: for example `i6_models.samplers.LogUniformSampler`.
        :param log_norm_term: normalisation term for true/sampled logits.
        :param reduction: reduction method for the binary cross entropy.
        :param device: device on which the loss will be placed.
        """
        super().__init__()

        self.num_samples = num_samples
        self.model = model  # only used to access weights of output layer for NCE computation
        self.noise_distribution_sampler = noise_distribution_sampler
        self.log_norm_term = log_norm_term
        self.device = device

        self._bce = nn.BCEWithLogitsLoss(reduction=reduction)
-    def forward(self, data: torch.Tensor, target: torch.Tensor):
+    def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # input: [B x T, F] target: [B x T]

        with torch.no_grad():
            samples = self.noise_distribution_sampler.sample(self.num_samples).cuda()

            # log-probabilities for the noise distribution k * q(w|h)
-            sampled_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(
-                samples
-            )  # [num_samples]
-            true_sample_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(target)  # [B x T]
+            ws = torch.log(torch.Tensor([self.num_samples]))
+            sampled_prob = ws + self.noise_distribution_sampler.log_prob(samples)  # [num_samples]
+            true_sample_prob = ws + self.noise_distribution_sampler.log_prob(target)  # [B x T]

        all_classes = torch.cat((target, samples), 0)  # [B x T + num_sampled]

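For context on what the visible hunk computes: NCE recasts the softmax as a binary classification between each true target and num_samples noise words, where the classifier logit is the model score minus log(k * q(w)), i.e. P(true | w, h) = sigmoid(s(w, h) - log(k * q(w))); the `ws + log_prob(...)` lines above build exactly that correction term. The rest of forward() is cut off in this hunk, so the following is only a minimal, self-contained sketch of the standard NCE computation under assumed names (`nce_loss_sketch`, `output_weight`, `output_bias`, `log_q` are illustrations, not the PR's API), not the PR's actual implementation.

import math

import torch
from torch.nn import functional as F


def nce_loss_sketch(
    hidden: torch.Tensor,         # [N, F] hidden states (batch and time flattened)
    target: torch.Tensor,         # [N] true class indices
    samples: torch.Tensor,        # [k] sampled noise class indices
    log_q: torch.Tensor,          # [vocab] log-probabilities of the noise distribution
    output_weight: torch.Tensor,  # [vocab, F] output layer weights (the role self.model plays in the PR)
    output_bias: torch.Tensor,    # [vocab] output layer bias
    log_norm_term: float = 0.0,   # assumed-constant log normalisation term
) -> torch.Tensor:
    k = samples.numel()

    # model scores s(w, h), computed only for the true and sampled classes (never the full vocabulary)
    true_logits = (hidden * output_weight[target]).sum(dim=1) + output_bias[target]  # [N]
    sample_logits = hidden @ output_weight[samples].t() + output_bias[samples]       # [N, k]

    # NCE correction: subtract log(k * q(w)), the term the diff builds as `ws + log_prob(...)`
    true_logits = true_logits - (math.log(k) + log_q[target]) - log_norm_term        # [N]
    sample_logits = sample_logits - (math.log(k) + log_q[samples]) - log_norm_term   # [N, k]

    # binary cross entropy with logits: label 1 for the true target, 0 for every noise sample
    logits = torch.cat((true_logits.unsqueeze(1), sample_logits), dim=1)  # [N, 1 + k]
    labels = torch.zeros_like(logits)
    labels[:, 0] = 1.0
    return F.binary_cross_entropy_with_logits(logits, labels, reduction="mean")


# toy shapes only, to show the call; a uniform noise distribution stands in for a log-uniform sampler
vocab, feat, n, k = 10_000, 512, 64, 2000
loss = nce_loss_sketch(
    hidden=torch.randn(n, feat),
    target=torch.randint(0, vocab, (n,)),
    samples=torch.randint(0, vocab, (k,)),
    log_q=torch.full((vocab,), -math.log(vocab)),
    output_weight=torch.randn(vocab, feat) * 0.01,
    output_bias=torch.zeros(vocab),
)
print(loss)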