Skip to content

Commit

Permalink
move load_datasets to orchestrate_rl
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexPiche committed Jan 22, 2025
1 parent 666c3d5 commit 0efa317
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 37 deletions.
34 changes: 33 additions & 1 deletion examples/rl_gsm8k/orchestrate_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import hydra
import torch
from datasets import load_dataset
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from tqdm import tqdm
Expand All @@ -31,12 +32,12 @@
)
from .deepseek_math_eval.answer_extraction import extract_last_single_answer, extract_math_answer
from .deepseek_math_eval.eval_script import eval_last_single_answer, eval_math
from .deepseek_math_eval.process_utils import process_eurus_test, process_gsm8k_test, process_math_test
from .utils import (
VLLMServiceManager,
calculate_stats,
clean_up,
launch_training,
load_datasets,
load_state,
save_state,
setup_logging,
Expand All @@ -45,6 +46,37 @@
logger = logging.getLogger(__name__)


def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
match cfg.dataset_name:
case "math":
train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
builder_config = "main"
case "gsm8k":
train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
process_fn = process_gsm8k_test
builder_config = "main"
case "eurus":
train_dataset_long_name = "PRIME-RL/Eurus-2-RL-Data"
test_dataset_long_name = "alexpiche/math_test_cleaned"
process_fn = process_eurus_test
builder_config = "default"
case _:
raise ValueError(f"Unknown dataset: {cfg.dataset_name}")

train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True)
test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True)
train_samples = [
process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None
]
test_samples = [
process_fn(s) for s in tqdm(test_dataset, desc="Processing test samples") if process_fn(s) is not None
]
logger.info(f"Loaded {len(train_samples)} training samples")
logger.info(f"Loaded {len(test_samples)} test samples")
return train_samples, test_samples


def batch_annotate_traces_with_ref_logprobs(llm: TrainableLLM, traces: List[TrainingText]):
prompt_token_ids = []
completion_token_ids = []
Expand Down
36 changes: 0 additions & 36 deletions examples/rl_gsm8k/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
import psutil
import requests
import torch
from datasets import load_dataset
from omegaconf import DictConfig
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from tapeagents.config import is_debug_mode
from tapeagents.llms import LLMOutput, Prompt

from .deepseek_math_eval.process_utils import process_eurus_test, process_gsm8k_test, process_math_test

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -427,34 +422,3 @@ def get_tokens_from_hf_tokenizer(tokenizer: PreTrainedTokenizer | None, prompt:
output_token_ids = text_token_ids[len(prompt_token_ids) :]
output_tokens = [tokenizer.decode(output_token_id) for output_token_id in output_token_ids]
return output_tokens


def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
match cfg.dataset_name:
case "math":
train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
process_fn = process_math_test
builder_config = "main"
case "gsm8k":
train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
process_fn = process_gsm8k_test
builder_config = "main"
case "eurus":
train_dataset_long_name = "PRIME-RL/Eurus-2-RL-Data"
test_dataset_long_name = "alexpiche/math_test_cleaned"
process_fn = process_eurus_test
builder_config = "default"
case _:
raise ValueError(f"Unknown dataset: {cfg.dataset_name}")

train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True)
test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True)
train_samples = [
process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None
]
test_samples = [
process_fn(s) for s in tqdm(test_dataset, desc="Processing test samples") if process_fn(s) is not None
]
logger.info(f"Loaded {len(train_samples)} training samples")
logger.info(f"Loaded {len(test_samples)} test samples")
return train_samples, test_samples

0 comments on commit 0efa317

Please sign in to comment.