Commit

clamp log_ratio_ref_new
AlexPiche committed Jan 9, 2025
1 parent d66b0d6 commit e7f8c7a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tapeagents/finetune/rl/__init__.py
@@ -108,7 +108,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
     log_p_weights = advantages if config.use_advantages else rewards
     log_p_weights = torch.clamp(log_p_weights, min=0) if config.relu_log_p_weights else log_p_weights
     # Second compute the approximated KL, see https://arxiv.org/pdf/2402.03300 eq 4
-    log_ratio_ref_new = ref_logprobs - new_log_probs
+    log_ratio_ref_new = torch.clamp(ref_logprobs - new_log_probs, -10, 10)
     approx_kl = torch.exp(log_ratio_ref_new) - log_ratio_ref_new - 1  # Schulman KL approx
     match config.algo:
         case "grpo":
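
The effect of the clamp on the Schulman KL approximation can be illustrated with a minimal sketch. It uses plain Python floats instead of the repository's tensors, and the function name `schulman_kl` and its `bound` parameter are hypothetical, introduced only for this illustration. With the log ratio limited to [-10, 10], a single term can contribute at most exp(10) - 10 - 1 to the estimate, whereas the unclamped term grows exponentially with the log ratio.

```python
import math
from typing import Optional


def schulman_kl(ref_logprob: float, new_logprob: float, bound: Optional[float] = 10.0) -> float:
    """Per-token KL estimate exp(r) - r - 1, with r = ref_logprob - new_logprob.

    Hypothetical scalar helper: `bound` mirrors the -10/10 limits in the diff;
    pass None to get the unclamped behaviour from before the commit.
    """
    log_ratio = ref_logprob - new_logprob
    if bound is not None:
        # Scalar equivalent of clamping the log ratio to [-bound, bound].
        log_ratio = max(-bound, min(bound, log_ratio))
    return math.exp(log_ratio) - log_ratio - 1.0


if __name__ == "__main__":
    # Identical log-probabilities: the estimate is exactly zero.
    print(schulman_kl(-1.0, -1.0))
    # A large positive log ratio (50) without clamping grows like exp(50).
    print(schulman_kl(0.0, -50.0, bound=None))
    # The same inputs with clamping are capped at exp(10) - 11.
    print(schulman_kl(0.0, -50.0))
```

The clamp trades exactness for stability: estimates for log ratios outside [-10, 10] are no longer faithful, but a single outlier token cannot dominate the averaged KL term.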
