Skip to content

Commit

Permalink
run on numina
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexPiche committed Jan 7, 2025
1 parent 6e2136b commit a58161d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
6 changes: 6 additions & 0 deletions conf/rl_numina.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
defaults:
- rl_gsm8k
- _self_

model_path: /mnt/llmd/base_models/Meta-Llama-3.1-8B-Instruct
dataset_name: numina
4 changes: 2 additions & 2 deletions examples/rl_gsm8k/deepseek_math_eval/process_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def process_math_test(item):
return
sample = {
"dataset": "math-cot",
"level": item["level"],
"type": item["type"],
#"level": item["level"],
#"type": item["type"],
"task": question,
"answer": answer
}
Expand Down
19 changes: 15 additions & 4 deletions examples/rl_gsm8k/orchestrate_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def extract_tape_training_samples(
tape_prompt_tokens = 0
tape_output_tokens = 0
match cfg.dataset_name:
case "numina":
eval_fn = eval_math
extract_fn = extract_math_answer
case "math":
eval_fn = eval_math
extract_fn = extract_math_answer
Expand Down Expand Up @@ -286,17 +289,25 @@ def main(cfg: DictConfig):

match cfg.dataset_name:
case "math":
dataset_long_name = "hendrycks/competition_math"
train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
case "gsm8k":
dataset_long_name = "openai/gsm8k"
train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
process_fn = process_gsm8k_test
case "numina":
train_dataset_long_name = "AI-MO/NuminaMath-CoT"
#TODO: think of a good test set
test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
case _:
raise ValueError(f"Unknown dataset: {cfg.dataset_name}")

train_dataset = load_dataset(dataset_long_name, "main", split="train", trust_remote_code=True)
if cfg.dataset_name == "numina":
train_dataset = load_dataset(train_dataset_long_name, split="train", trust_remote_code=True)
else:
train_dataset = load_dataset(train_dataset_long_name, "main", split="train", trust_remote_code=True)
train_samples = [process_fn(s) for s in train_dataset]
test_dataset = load_dataset(dataset_long_name, "main", split="test", trust_remote_code=True)
test_dataset = load_dataset(test_dataset_long_name, "main", split="test", trust_remote_code=True)
test_samples = [process_fn(s) for s in test_dataset]
logger.info(f"Loaded {len(train_samples)} training samples")
logger.info(f"Loaded {len(test_samples)} test samples")
Expand Down
6 changes: 4 additions & 2 deletions examples/rl_gsm8k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,13 @@ def _cleanup(self) -> None:
f.close()

def __enter__(self) -> "VLLMServiceManager":
self._start_service()
#self._start_service()
self.ports = [8080, 8081]
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self._cleanup()
#self._cleanup()
return None

def get_stats(self):
return self.stats
Expand Down

0 comments on commit a58161d

Please sign in to comment.