Merge pull request #167 from ServiceNow/eurus
[WIP] Eurus
AlexPiche authored Jan 24, 2025
2 parents 0b8a786 + ebf5aaa commit d7258ba
Showing 14 changed files with 21,603 additions and 2,871 deletions.
29 changes: 29 additions & 0 deletions conf/rl_eurus.yaml
@@ -0,0 +1,29 @@
+defaults:
+  - rl_gsm8k
+  - _self_
+model_path: /mnt/llmd/base_models/Eurus-2-7B-SFT
+dataset_name: eurus
+output_dir: outputs/rl_eurus
+attempts: 1
+llm:
+  parameters:
+    # CoTs are much longer, but the model only has a 4096-token context
+    max_tokens: 3072
+
+# Eurus already formatted the dataset as {task}\n\nPresent the answer in LaTex format: \\boxed{Your answer},
+# thus we can simply use the identity task template
+task_template: |-
+  {task}
+# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/eval/system_prompt.md?plain=1
+# but note that sometimes they do not include the newline at the beginning:
+# https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/data_preprocessing/sft_prompt.py#L1
+system_prompt: |-
+  \nWhen tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process.\n\n[ASSESS]\n\n[ADVANCE]\n\n[VERIFY]\n\n[SIMPLIFY]\n\n[SYNTHESIZE]\n\n[PIVOT]\n\n[OUTPUT]\n\nYou should strictly follow the format below:\n\n[ACTION NAME]\n\n# Your action step 1\n\n# Your action step 2\n\n# Your action step 3\n\n...\n\nNext action: [NEXT ACTION NAME]\n\n
+
+test_every_n_iterations: 1
+vllm_config:
+  vllm_kwargs:
+    --gpu-memory-utilization: 0.99
+    --max-num-seqs: 512
+    --pipeline-parallel-size: 1
+    --tensor-parallel-size: 1
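Since this file only overrides a handful of keys on top of the base config, a quick way to sanity-check the composed result is Hydra's compose API. A minimal sketch, assuming the repo keeps these files under a `conf/` directory (key names taken from the diff above):

```python
# Compose rl_eurus.yaml on top of rl_gsm8k.yaml the way Hydra would:
# `defaults: [rl_gsm8k, _self_]` loads the base config first, then applies
# this file's own keys as overrides. Sketch only; assumes a conf/ layout.
from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="rl_eurus")

print(cfg.model_path)                 # /mnt/llmd/base_models/Eurus-2-7B-SFT
print(cfg.llm.parameters.max_tokens)  # 3072, overriding the base value
```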
14 changes: 11 additions & 3 deletions conf/rl_gsm8k.yaml
@@ -5,7 +5,6 @@ defaults:
 dataset_name: gsm8k
 n_workers_per_gpu: 32
 get_logprobs_workers_per_gpu: 4
-gpus_per_model_instance: 1
 max_loops: 1
 test_every_n_iterations: 5
 model_path: /mnt/llmd/base_models/deepseek-math-7b-instruct
@@ -16,7 +15,7 @@ max_iterations: 1000
 use_rejection_sampling: false
 llm:
   parameters:
-    max_tokens: 1024
+    max_tokens: 3072
     temperature: 0.7
 test_llm:
   parameters:
@@ -33,7 +32,14 @@ finetune:
   # One step is one weight update. See the finetuning configuration
   # for the info on how many sequences are used for each weight update.
   save_checkpoint_steps: 10
-  seq_length: ${..llm.parameters.max_tokens}
+  seq_length: 4096
+
+system_prompt: ""
+# Same prompt as https://github.com/deepseek-ai/DeepSeek-Math/blob/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/run_subset_parallel.py#L26
+task_template: |-
+  {task}\nPlease reason step by step, and put your final answer within \boxed{{}}.
+overflow_reward: 0
+max_prompt_length: 1024
+
 vllm_config:
   vllm_kwargs:
@@ -44,6 +50,8 @@ vllm_config:
     --max-num-seqs: 1024
     --enforce-eager: ""
     --return-tokens-as-token-ids: ""
+    --pipeline-parallel-size: 1
+    --tensor-parallel-size: 1
   actor_vllm_kwargs:
     --num-scheduler-steps: 16
   ref_vllm_kwargs:
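The `vllm_kwargs` mapping mirrors vLLM's CLI flags, with an empty string standing in for a bare boolean flag such as `--enforce-eager`. A hypothetical helper (not from the repo) showing how such a mapping could flatten into an argument vector:

```python
# Hypothetical: flatten a vllm_kwargs-style mapping into CLI arguments.
# An empty-string value marks a bare flag that takes no argument.
def vllm_kwargs_to_argv(vllm_kwargs: dict) -> list[str]:
    argv: list[str] = []
    for flag, value in vllm_kwargs.items():
        argv.append(flag)
        if value != "":
            argv.append(str(value))
    return argv

print(vllm_kwargs_to_argv({
    "--gpu-memory-utilization": 0.99,
    "--enforce-eager": "",
    "--tensor-parallel-size": 1,
}))
# ['--gpu-memory-utilization', '0.99', '--enforce-eager', '--tensor-parallel-size', '1']
```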
33 changes: 28 additions & 5 deletions examples/rl_gsm8k/cot_math_agent.py
@@ -1,12 +1,13 @@
 import logging
-from typing import Annotated, Generator, Literal, TypeAlias, Union
+from typing import Annotated, Any, Generator, Literal, TypeAlias, Union

 from pydantic import Field

 from tapeagents.agent import Agent
 from tapeagents.core import (
     LLMOutputParsingFailureAction,
     Observation,
+    Prompt,
     Step,
     Tape,
     Thought,
@@ -20,10 +21,12 @@
 class Task(Observation):
     kind: Literal["task"] = "task"
     task: str
+    template: str = Field(
+        description="Template for the task. Should contain a {task} placeholder for the task text.", default="{task}"
+    )

     def llm_view(self, indent: int | None = 2) -> str:
-        # Same prompt as https://github.com/deepseek-ai/DeepSeek-Math/blob/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/run_subset_parallel.py#L26
-        return f"{self.task}\nPlease reason step by step, and put your final answer within " + "\\boxed{}."
+        return self.template.format(task=self.task)


 class ReasoningThought(Thought):
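With the template now a field on `Task`, the prompt wording travels with the data instead of being hard-coded in `llm_view`. A quick sketch of the behavior (values illustrative, not from the repo's tests):

```python
# The default template is the identity, as used by rl_eurus.yaml;
# rl_gsm8k.yaml supplies the DeepSeek-Math wording shown here.
step = Task(
    task="What is 2 + 2?",
    template="{task}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
)
print(step.llm_view())
# What is 2 + 2?
# Please reason step by step, and put your final answer within \boxed{}.
```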
@@ -51,6 +54,8 @@ class ReasoningThought(Thought):


 class ReasoningNode(MonoNode):
+    max_prompt_length: int = 1024
+
     def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
         try:
             step = ReasoningThought(reasoning=completion)
@@ -62,21 +67,39 @@ def parse_completion(self, completion: str, prompt_id: str) -> Generator[Step, None, None]:
             return
         yield step

+    def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
+        messages = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+
+        # the tape is only one step long, and that step is the task
+        task = tape.steps[0]
+        assert isinstance(task, Task), f"Expected a Task, got {task.__class__.__name__}"
+        messages.append({"role": "user", "content": task.llm_view()})
+        prompt_token_ids = agent.llm.tokenizer.apply_chat_template(
+            messages, add_special_tokens=True, add_generation_prompt=True
+        )
+        prompt_token_ids = prompt_token_ids[-self.max_prompt_length :]
+        return Prompt(messages=messages, token_ids=prompt_token_ids)
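The slice on the last line keeps only the final `max_prompt_length` token ids, so an over-long task is truncated from the left and the end of the prompt (the generation cue) survives. A standalone sketch with toy ids:

```python
# Toy illustration of left-truncation: keep the last max_prompt_length ids.
max_prompt_length = 8
prompt_token_ids = list(range(12))            # pretend 12-token prompt
print(prompt_token_ids[-max_prompt_length:])  # [4, 5, 6, 7, 8, 9, 10, 11]
```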


 #### Agent and Environment ####
 class CoTMathAgent(Agent):
     @classmethod
-    def create(cls, llm: LLM):
+    def create(cls, system_prompt: str, llm: LLM, max_prompt_length: int):
         agent = super().create(
             llm,
             nodes=[
                 ReasoningNode(
                     name="cot",
                     agent_step_cls=MathAgentStep,
                     store_llm_calls=True,
+                    system_prompt=system_prompt,
+                    max_prompt_length=max_prompt_length,
                 ),
             ],
             max_iterations=1,
         )
         agent.store_llm_calls = True
+        if agent.llm.tokenizer is None:
+            agent.llm.load_tokenizer()
         return agent
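A hypothetical call site (config attribute names follow `rl_gsm8k.yaml` above; not verbatim from the training script) showing how the new signature threads the prompt settings down to the node:

```python
# Sketch: wire the config's prompt settings into the agent.
agent = CoTMathAgent.create(
    system_prompt=cfg.system_prompt,          # "" for GSM8K, Eurus prompt for rl_eurus
    llm=llm,
    max_prompt_length=cfg.max_prompt_length,  # 1024 in rl_gsm8k.yaml
)
```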
9 changes: 6 additions & 3 deletions examples/rl_gsm8k/deepseek_math_eval/eval_utils.py
@@ -216,7 +216,8 @@ def math_equal(
                 else:
                     if item == prediction:
                         return True
-            except Exception:
+            except Exception as e:
+                raise e
                 continue
         return False
except:
@@ -299,8 +300,10 @@ def math_equal(
         if call_with_timeout(symbolic_equal_process, prediction, reference):
             return True
     else:
-        if symbolic_equal(prediction, reference):
-            return True
+        # symbolic_equal hangs
+        # if symbolic_equal(prediction, reference):
+        #     return True
+        pass

     return False

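The surviving path relies on `call_with_timeout` running the sympy comparison in a separate process, the usual guard against `simplify` hanging. A sketch of that pattern (names and signatures assumed, not the repo's exact implementation):

```python
# Sketch of a process-based timeout guard for a sympy equality check.
import multiprocessing

def symbolic_equal_process(a, b, out):
    # Runs in a child process; reports the comparison result via a queue.
    from sympy import simplify, sympify
    try:
        out.put(simplify(sympify(a) - sympify(b)) == 0)
    except Exception:
        out.put(False)

def call_with_timeout(fn, *args, timeout: float = 5.0) -> bool:
    out = multiprocessing.Queue()
    proc = multiprocessing.Process(target=fn, args=(*args, out))
    proc.start()
    proc.join(timeout)
    if proc.is_alive():  # comparison hung: kill it and treat as not equal
        proc.terminate()
        proc.join()
        return False
    return not out.empty() and out.get()
```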
27 changes: 27 additions & 0 deletions examples/rl_gsm8k/deepseek_math_eval/process_utils.py
@@ -5,6 +5,33 @@
 from examples.rl_gsm8k.deepseek_math_eval.eval_utils import parse_ground_truth


+def process_eurus_test(item):
+    if "ability" not in item:
+        # MATH-500 test set
+        answer = [item["expected_answer"]]
+        return {
+            "dataset": "math500",
+            # Same prompt as https://github.com/PRIME-RL/PRIME/blob/49a58a8e4afd464f559f8d9f80418052f29cf3e4/README.md?plain=1#L93
+            "task": item["problem"] + "\n\nPresent the answer in LaTex format: \\boxed{Your answer}",
+            "answer": answer
+        }
+    else:
+        # Eurus train set
+        if item["ability"] != "math":
+            return None
+        answer = item["reward_model"]["ground_truth"]
+        # matrices are formatted with newlines; remove them
+        answer = answer.replace("\n", "")
+        answer = "\\boxed{" + answer + "}"
+        answer = extract_math_answer(item["prompt"][1]["content"], answer, task="cot")
+        return {
+            "dataset": item["data_source"],
+            "task": item["prompt"][1]["content"],
+            "answer": answer
+        }
+
+
 def process_gsm8k_test(item):
     _, answer = parse_ground_truth(item, "gsm8k")
     sample = {"dataset": "gsm8k-cot", "task": item["question"], "answer": answer}
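A hypothetical Eurus-style training record (field names inferred from the function above; all values made up) to show the two branches:

```python
# Made-up example record exercising the Eurus train-set branch.
item = {
    "ability": "math",
    "data_source": "numina",  # hypothetical value
    "reward_model": {"ground_truth": "\\frac{1}{2}"},
    "prompt": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "What is 1 divided by 2?"},
    ],
}
sample = process_eurus_test(item)
# {"dataset": "numina", "task": "What is 1 divided by 2?", "answer": ...}
# Records without an "ability" key instead take the MATH-500 branch.
```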
