Skip to content

Commit

Permalink
filter out the nans
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexPiche committed Jan 8, 2025
1 parent a2780bd commit d2785f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 1 addition & 4 deletions tapeagents/finetune/rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
loss = -masked_mean(new_log_probs * log_p_weights - config.kl_coef * approx_kl, masks_)
case _:
raise ValueError(f"Unknown algorithm {config.algo}")
if not torch.isfinite(loss).all():
logger.warning("Loss is not finite and will be discarded")
loss = torch.nan_to_num(loss)


stats = {
"max_new_log_probs": new_log_probs[masks_].max().item(),
"max_ratio_new_old": ratio_new_old[masks_].max().item(),
Expand Down
5 changes: 5 additions & 0 deletions tapeagents/finetune/rl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def to_dict(self):
def get_avg_rl_stats(rl_stats):
avg_rl_stats: dict[str, float] = {}
for k, v in rl_stats.items():
if not np.isfinite(v).all():
continue
if "min" in k:
op = torch.min
elif "max" in k:
Expand All @@ -32,6 +34,7 @@ def get_avg_rl_stats(rl_stats):

def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
"""Compute sum of tensor with a masked values."""
values = torch.nan_to_num(values, nan=0.0)
if axis is not None:
return (values * mask).sum(axis=axis) # type: ignore
else:
Expand All @@ -40,6 +43,8 @@ def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] =

def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
# set the value to 0 if it is not finite
values = torch.nan_to_num(values, nan=0.0)
if axis is not None:
return (values * mask).sum(axis=axis) / mask.sum(axis=axis) # type: ignore
else:
Expand Down

0 comments on commit d2785f2

Please sign in to comment.