diff --git a/tapeagents/finetune/rl/__init__.py b/tapeagents/finetune/rl/__init__.py
index cf861c08..7b3a4318 100644
--- a/tapeagents/finetune/rl/__init__.py
+++ b/tapeagents/finetune/rl/__init__.py
@@ -108,7 +108,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
     log_p_weights = advantages if config.use_advantages else rewards
     log_p_weights = torch.clamp(log_p_weights, min=0) if config.relu_log_p_weights else log_p_weights
     # Second compute the approximated KL, see https://arxiv.org/pdf/2402.03300 eq 4
-    log_ratio_ref_new = ref_logprobs - new_log_probs
+    log_ratio_ref_new = torch.clamp(ref_logprobs - new_log_probs, -10, 10)
     approx_kl = torch.exp(log_ratio_ref_new) - log_ratio_ref_new - 1  # Schulman KL approx
     match config.algo:
         case "grpo":
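
The following is a minimal sketch (not part of the diff, with hypothetical per-token log-probs) illustrating why clamping the log-ratio before the Schulman k3 estimator `exp(x) - x - 1` matters: when the new policy assigns a token a much lower log-prob than the reference policy, `exp(log_ratio)` grows explosively, so one token can dominate the KL penalty or overflow in low-precision training.

```python
import torch

# Hypothetical per-token log-probs; the last token is one the policy drifted far on.
ref_logprobs = torch.tensor([-1.0, -2.0, -1.0])
new_log_probs = torch.tensor([-1.1, -2.5, -30.0])

log_ratio = ref_logprobs - new_log_probs            # [0.1, 0.5, 29.0]
kl_unclamped = torch.exp(log_ratio) - log_ratio - 1
print(kl_unclamped)   # ~[0.005, 0.149, 3.9e12] -- the last token swamps the loss

log_ratio_clamped = torch.clamp(log_ratio, -10, 10)
kl_clamped = torch.exp(log_ratio_clamped) - log_ratio_clamped - 1
print(kl_clamped)     # ~[0.005, 0.149, 2.2e4] -- bounded by exp(10) - 10 - 1
```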