From a58161dd39ec40c5f61c10faa0d0dbe48d846e44 Mon Sep 17 00:00:00 2001 From: Alexandre Piche Date: Tue, 7 Jan 2025 21:38:21 +0000 Subject: [PATCH] run on numina --- conf/rl_numina.yaml | 6 ++++++ .../deepseek_math_eval/process_utils.py | 4 ++-- examples/rl_gsm8k/orchestrate_rl.py | 19 +++++++++++++++---- examples/rl_gsm8k/utils.py | 6 ++++-- 4 files changed, 27 insertions(+), 8 deletions(-) create mode 100644 conf/rl_numina.yaml diff --git a/conf/rl_numina.yaml b/conf/rl_numina.yaml new file mode 100644 index 00000000..29a4ee43 --- /dev/null +++ b/conf/rl_numina.yaml @@ -0,0 +1,6 @@ +defaults: + - rl_gsm8k + - _self_ + +model_path: /mnt/llmd/base_models/Meta-Llama-3.1-8B-Instruct +dataset_name: numina \ No newline at end of file diff --git a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py index 29f9d25f..374fa8aa 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py +++ b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py @@ -21,8 +21,8 @@ def process_math_test(item): return sample = { "dataset": "math-cot", - "level": item["level"], - "type": item["type"], + #"level": item["level"], + #"type": item["type"], "task": question, "answer": answer } diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 3bc7bb02..aeb0b5e0 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -115,6 +115,9 @@ def extract_tape_training_samples( tape_prompt_tokens = 0 tape_output_tokens = 0 match cfg.dataset_name: + case "numina": + eval_fn = eval_math + extract_fn = extract_math_answer case "math": eval_fn = eval_math extract_fn = extract_math_answer @@ -286,17 +289,25 @@ def main(cfg: DictConfig): match cfg.dataset_name: case "math": - dataset_long_name = "hendrycks/competition_math" + train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math" process_fn = process_math_test case "gsm8k": - dataset_long_name = "openai/gsm8k" + train_dataset_long_name = test_dataset_long_name = "openai/gsm8k" process_fn = process_gsm8k_test + case "numina": + train_dataset_long_name = "AI-MO/NuminaMath-CoT" + #TODO: think of a good test set + test_dataset_long_name = "hendrycks/competition_math" + process_fn = process_math_test case _: raise ValueError(f"Unknown dataset: {cfg.dataset_name}") - train_dataset = load_dataset(dataset_long_name, "main", split="train", trust_remote_code=True) + if cfg.dataset_name == "numina": + train_dataset = load_dataset(train_dataset_long_name, split="train", trust_remote_code=True) + else: + train_dataset = load_dataset(train_dataset_long_name, "main", split="train", trust_remote_code=True) train_samples = [process_fn(s) for s in train_dataset] - test_dataset = load_dataset(dataset_long_name, "main", split="test", trust_remote_code=True) + test_dataset = load_dataset(test_dataset_long_name, "main", split="test", trust_remote_code=True) test_samples = [process_fn(s) for s in test_dataset] logger.info(f"Loaded {len(train_samples)} training samples") logger.info(f"Loaded {len(test_samples)} test samples") diff --git a/examples/rl_gsm8k/utils.py b/examples/rl_gsm8k/utils.py index 6856ab6c..75d1db03 100644 --- a/examples/rl_gsm8k/utils.py +++ b/examples/rl_gsm8k/utils.py @@ -216,11 +216,13 @@ def _cleanup(self) -> None: f.close() def __enter__(self) -> "VLLMServiceManager": - self._start_service() + #self._start_service() + self.ports = [8080, 8081] return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self._cleanup() + #self._cleanup() + return None def get_stats(self): return self.stats