diff --git a/tapeagents/finetune/rl/__init__.py b/tapeagents/finetune/rl/__init__.py index 8c4afb8d..df284644 100644 --- a/tapeagents/finetune/rl/__init__.py +++ b/tapeagents/finetune/rl/__init__.py @@ -130,10 +130,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc loss = -masked_mean(new_log_probs * log_p_weights - config.kl_coef * approx_kl, masks_) case _: raise ValueError(f"Unknown algorithm {config.algo}") - if not torch.isfinite(loss).all(): - logger.warning("Loss is not finite and will be discarded") - loss = torch.nan_to_num(loss) - + stats = { "max_new_log_probs": new_log_probs[masks_].max().item(), "max_ratio_new_old": ratio_new_old[masks_].max().item(), diff --git a/tapeagents/finetune/rl/utils.py b/tapeagents/finetune/rl/utils.py index f2da96ef..531a1cc8 100644 --- a/tapeagents/finetune/rl/utils.py +++ b/tapeagents/finetune/rl/utils.py @@ -20,6 +20,8 @@ def to_dict(self): def get_avg_rl_stats(rl_stats): avg_rl_stats: dict[str, float] = {} for k, v in rl_stats.items(): + if not np.isfinite(v).all(): + continue if "min" in k: op = torch.min elif "max" in k: @@ -32,6 +34,7 @@ def get_avg_rl_stats(rl_stats): def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: """Compute sum of tensor with a masked values.""" + values = torch.nan_to_num(values, nan=0.0) if axis is not None: return (values * mask).sum(axis=axis) # type: ignore else: @@ -40,6 +43,8 @@ def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: """Compute mean of tensor with a masked values.""" + # replace NaN values with 0.0 (note: nan_to_num clamps +/-inf to the dtype's finite extremes, it does not zero them) + values = torch.nan_to_num(values, nan=0.0) if axis is not None: return (values * mask).sum(axis=axis) / mask.sum(axis=axis) # type: ignore else: