Skip to content

Commit

Permalink
Merge branch 'eurus' into now-reasoner-eurus
Browse files Browse the repository at this point in the history
  • Loading branch information
rizar committed Jan 22, 2025
2 parents 97ebcac + 666c3d5 commit fecf3bc
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 87 deletions.
13 changes: 10 additions & 3 deletions conf/rl_eurus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@ defaults:
model_path: /mnt/llmd/base_models/Eurus-2-7B-SFT
dataset_name: eurus
output_dir: outputs/rl_eurus
finetune:
max_seq_len: 4096
attempts: 1
llm:
parameters:
# CoTs can be much longer, but the model only has a 4096-token context
max_tokens: 3072

# EURUS already applies this template: {task}\n\nPresent the answer in LaTex format: \\boxed{Your answer}
task_template: |-
{task}
# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/eval/system_prompt.md?plain=1
# but note that sometimes they do not include the newline at the beginning
# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/data_preprocessing/sft_prompt.py#L1
system_prompt: |-
\nWhen tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process.\n\n[ASSESS]\n\n[ADVANCE]\n\n[VERIFY]\n\n[SIMPLIFY]\n\n[SYNTHESIZE]\n\n[PIVOT]\n\n[OUTPUT]\n\nYou should strictly follow the format below:\n\n[ACTION NAME]\n\n# Your action step 1\n\n# Your action step 2\n\n# Your action step 3\n\n...\n\nNext action: [NEXT ACTION NAME]\n\n

test_every_n_iterations: 1
vllm_config:
vllm_kwargs:
--gpu-memory-utilization: 0.99
--max-num-seqs: 1024
--max-num-seqs: 512
--pipeline-parallel-size: 1
--tensor-parallel-size: 1
11 changes: 7 additions & 4 deletions conf/rl_gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ force_restart: false
max_iterations: 1000
use_rejection_sampling: false
llm:
max_prompt_length: 1024
parameters:
max_tokens: 3072
max_tokens: 1024
temperature: 0.7
test_llm:
max_prompt_length: ${..llm.max_prompt_length}
parameters:
max_tokens: ${...llm.parameters.max_tokens}
temperature: 0.
Expand All @@ -34,7 +32,12 @@ finetune:
# One step is one weight update. See the finetuning configuration
# for details on how many sequences are used for each weight update.
save_checkpoint_steps: 10
seq_length: ${..llm.parameters.max_tokens} + ${..llm.max_prompt_length}
seq_length: 4096

system_prompt: null
# Same prompt as https://github.com/deepseek-ai/DeepSeek-Math/blob/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/run_subset_parallel.py#L26
task_template: |-
{task}\nPlease reason step by step, and put your final answer within \boxed{{}}.
vllm_config:
vllm_kwargs:
Expand Down
19 changes: 9 additions & 10 deletions examples/rl_gsm8k/cot_math_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@

logger = logging.getLogger(__name__)

# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/eval/system_prompt.md?plain=1
# but note that sometimes they do not include the newline at the beginning
# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/data_preprocessing/sft_prompt.py#L1
EURUS_SYSTEM_PROMPT = """\nWhen tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process.\n\n[ASSESS]\n\n[ADVANCE]\n\n[VERIFY]\n\n[SIMPLIFY]\n\n[SYNTHESIZE]\n\n[PIVOT]\n\n[OUTPUT]\n\nYou should strictly follow the format below:\n\n[ACTION NAME]\n\n# Your action step 1\n\n# Your action step 2\n\n# Your action step 3\n\n...\n\nNext action: [NEXT ACTION NAME]\n\n"""

class Task(Observation):
kind: Literal["task"] = "task"
task: str
template: str = Field(
description="Template for the task. Should contain a {task} placeholder for the task text.", default="{task}"
)

def llm_view(self, indent: int | None = 2) -> str:
# Same prompt as https://github.com/deepseek-ai/DeepSeek-Math/blob/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/run_subset_parallel.py#L26
#return f"{self.task}\nPlease reason step by step, and put your final answer within " + "\\boxed{}."
return self.task
return self.template.format(task=self.task)


class ReasoningThought(Thought):
Expand Down Expand Up @@ -58,6 +55,8 @@ class ReasoningThought(Thought):


class ReasoningNode(MonoNode):
trim_tape_when_too_long: bool = False

def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
try:
step = ReasoningThought(reasoning=completion)
Expand All @@ -73,17 +72,17 @@ def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, N
#### Agent and Environment ####
class CoTMathAgent(Agent):
@classmethod
def create(cls, llm: LLM):
def create(cls, system_prompt: str, llm: LLM):
agent = super().create(
llm,
nodes=[
ReasoningNode(
name="cot",
agent_step_cls=MathAgentStep,
system_prompt=EURUS_SYSTEM_PROMPT,
system_prompt=system_prompt if system_prompt else "",
),
],
max_iterations=1,
)
agent.store_llm_calls = True
return agent
return agent
8 changes: 0 additions & 8 deletions examples/rl_gsm8k/deepseek_math_eval/process_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,6 @@ def process_eurus_test(item):
if "ability" not in item:
# math 500 test set
answer = [item["expected_answer"]]
solution = item["solution"]
# Eurus will produce \\ as \\\\
solution = solution.replace("\\ ", "\\\\ ")
answer2 = extract_math_answer(item["problem"] , solution, task="cot")
if answer2 != answer:
print("Answer mismatch")
print("Old answer:", answer2)
print("New answer:", answer)
return {
"dataset": "math500",
# Same prompt as https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/README.md?plain=1#L93
Expand Down
45 changes: 10 additions & 35 deletions examples/rl_gsm8k/orchestrate_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from typing import Dict, List, Tuple

import hydra
import numpy as np
import torch
import wandb
from datasets import load_dataset
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from tqdm import tqdm
Expand All @@ -34,21 +31,16 @@
)
from .deepseek_math_eval.answer_extraction import extract_last_single_answer, extract_math_answer
from .deepseek_math_eval.eval_script import eval_last_single_answer, eval_math
from .deepseek_math_eval.process_utils import process_eurus_test, process_gsm8k_test, process_math_test
from .utils import (
VLLMServiceManager,
calculate_stats,
clean_up,
launch_training,
load_datasets,
load_state,
save_state,
setup_logging,
)
from tapeagents.orchestrator import main_loop

from .cot_math_agent import CoTMathAgent, RLMathTape, Task
from .utils import VLLMServiceManager, calculate_stats, clean_up, launch_training, load_state, save_state, setup_logging


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,6 +78,7 @@ def convert_problems_to_tapes(problems: list, cfg: DictConfig) -> list[RLMathTap
for problem in problems:
start_step = Task(
task=problem["task"],
template=cfg.task_template,
metadata=StepMetadata(
other={
"value": problem["answer"],
Expand Down Expand Up @@ -152,7 +145,7 @@ def extract_tape_training_samples(
tape_output_tokens += llm_call.output_length_tokens

if llm_call.output_length_tokens >= cfg.llm.parameters.max_tokens:
# Output is too long, ignore this sample
# ignore this sample
# this will be recorded in output_tokens_overflow
continue

Expand Down Expand Up @@ -289,27 +282,7 @@ def main(cfg: DictConfig):
if cfg.force_restart:
clean_up(exp_path, state, state_path)

match cfg.dataset_name:
case "math":
train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
case "gsm8k":
train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
process_fn = process_gsm8k_test
case "eurus":
train_dataset_long_name = "PRIME-RL/Eurus-2-RL-Data"
test_dataset_long_name = "alexpiche/math_test_cleaned"
process_fn = process_eurus_test
case _:
raise ValueError(f"Unknown dataset: {cfg.dataset_name}")

train_dataset = load_dataset(train_dataset_long_name, split="train", trust_remote_code=True)
test_dataset = load_dataset(test_dataset_long_name, split="test", trust_remote_code=True)
test_samples = [process_fn(s) for s in tqdm(test_dataset, desc="Processing test samples") if process_fn(s) is not None]
train_samples = [process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None]
logger.info(f"Loaded {len(train_samples)} training samples")
logger.info(f"Loaded {len(test_samples)} test samples")

train_samples, test_samples = load_datasets(cfg)
conf_dir = exp_path / "conf"
os.makedirs(conf_dir, exist_ok=True)
finetune_path = exp_path / "finetune"
Expand Down Expand Up @@ -345,7 +318,6 @@ def main(cfg: DictConfig):
use_cache=False,
collect_logprobs=True,
observe_llm_calls=False,
max_prompt_length=cfg.llm.max_prompt_length,
)
for base_url in vllm_service_manager.get_base_urls()
]
Expand All @@ -358,17 +330,20 @@ def main(cfg: DictConfig):
parameters=cfg.test_llm.parameters,
use_cache=False,
observe_llm_calls=False,
max_prompt_length=cfg.test_llm.max_prompt_length,
)
for base_url in vllm_service_manager.get_base_urls()
]

train_agent_replicas = [CoTMathAgent.create(llm=llm) for llm in train_llms]
train_agent_replicas = [
CoTMathAgent.create(system_prompt=cfg.system_prompt, llm=llm) for llm in train_llms
]

splits = [("train", train_agent_replicas, train_tapes)]
if state["iteration"] % cfg.test_every_n_iterations == 0 and cfg.test_every_n_iterations > 0:
test_tapes = convert_problems_to_tapes(test_samples, cfg)
test_agent_replicas = [CoTMathAgent.create(llm=llm) for llm in test_llms]
test_agent_replicas = [
CoTMathAgent.create(system_prompt=cfg.system_prompt, llm=llm) for llm in test_llms
]
splits.append(("test", test_agent_replicas, test_tapes))
for split_name, agent_replicas, tapes in splits:
tapes_dir = exp_path / "tapes" / split_name / str(state["iteration"])
Expand Down
46 changes: 39 additions & 7 deletions examples/rl_gsm8k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, TextIO, Union
from typing import Dict, List, Optional, TextIO, Tuple, Union

import numpy as np
import psutil
import requests
import torch
from datasets import load_dataset
from omegaconf import DictConfig
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from tapeagents.llms import LLMOutput, Prompt
from tapeagents.config import is_debug_mode
from tapeagents.llms import LLMOutput, Prompt

from .deepseek_math_eval.process_utils import process_eurus_test, process_gsm8k_test, process_math_test

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -304,7 +309,7 @@ def remove_dir(directory: Path):
save_state(state, state_path)

logger.info("Cleaning up checkpoints and training state")
# list of log files to erase
# list of log files to erase
log_files = list(target_path.glob("*.log"))

for file in log_files:
Expand Down Expand Up @@ -393,15 +398,11 @@ def launch_training(config_dir: str, config_name: str, accelerate_cfg_path: str,

logger.info(f"Launching training with command: {' '.join(base_cmd)}")
try:
# set env variable TORCH_NCCL_ENABLE_MONITORING=0
env = os.environ.copy()
env["TORCH_NCCL_ENABLE_MONITORING"] = "0"
subprocess.run(
base_cmd,
check=True, # Raises CalledProcessError if return code != 0
text=True,
capture_output=False,
env=env
)

except subprocess.CalledProcessError as e:
Expand All @@ -426,3 +427,34 @@ def get_tokens_from_hf_tokenizer(tokenizer: PreTrainedTokenizer | None, prompt:
output_token_ids = text_token_ids[len(prompt_token_ids) :]
output_tokens = [tokenizer.decode(output_token_id) for output_token_id in output_token_ids]
return output_tokens


def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
match cfg.dataset_name:
case "math":
train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
builder_config = "main"
case "gsm8k":
train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
process_fn = process_gsm8k_test
builder_config = "main"
case "eurus":
train_dataset_long_name = "PRIME-RL/Eurus-2-RL-Data"
test_dataset_long_name = "alexpiche/math_test_cleaned"
process_fn = process_eurus_test
builder_config = "default"
case _:
raise ValueError(f"Unknown dataset: {cfg.dataset_name}")

train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True)
test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True)
train_samples = [
process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None
]
test_samples = [
process_fn(s) for s in tqdm(test_dataset, desc="Processing test samples") if process_fn(s) is not None
]
logger.info(f"Loaded {len(train_samples)} training samples")
logger.info(f"Loaded {len(test_samples)} test samples")
return train_samples, test_samples
4 changes: 1 addition & 3 deletions tapeagents/finetune/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import time
from concurrent import futures
from functools import partial
from typing import Any, Callable, Iterable, Sequence

Expand All @@ -18,7 +16,6 @@
from tapeagents.core import TrainingText

from .context import accelerator, logger
from .logging_ import log_time
from .rl import RL_DATA_COLUMNS, prepare_rl_fields
from .types import DataArgs, DataPartArgs

Expand Down Expand Up @@ -244,6 +241,7 @@ def create_dataloader(
logger.info(f"Merged data fingerprint: {data._fingerprint}")

if rl_data_callback is not None:
accelerator.wait_for_everyone()
data = rl_data_callback(dataset=data, columns=columns, collate_fn=collate_fn)

if n_examples:
Expand Down
Loading

0 comments on commit fecf3bc

Please sign in to comment.