From cb77e17ada466ec537fa8441f343501a3ed92ec0 Mon Sep 17 00:00:00 2001 From: Jordan Prince Tremblay Date: Thu, 9 Jan 2025 22:42:45 -0500 Subject: [PATCH] apply lint --- examples/continue_tapes.py | 2 +- examples/data_science/tape_browser.py | 2 +- examples/delegate.py | 2 +- examples/form_filler/types.py | 1 - examples/form_filler/utils.py | 10 +- examples/gaia_agent/eval.py | 10 +- examples/gaia_agent/scripts/debug.py | 4 +- examples/gaia_agent/scripts/studio.py | 4 +- examples/gsm8k_tuning/math_agent.py | 2 +- examples/optimize/func_templates.py | 2 +- examples/optimize/load_demos.py | 1 - examples/rl_gsm8k/browse.py | 4 +- .../deepseek_math_eval/answer_extraction.py | 104 +++++----- .../deepseek_math_eval/eval_script.py | 116 ++++++----- .../rl_gsm8k/deepseek_math_eval/eval_utils.py | 186 ++++++++++-------- .../ocw_courses_eval_utils.py | 23 ++- .../deepseek_math_eval/process_utils.py | 160 ++++++--------- examples/rl_gsm8k/gather_jsons.py | 2 +- examples/rl_gsm8k/orchestrate_rl.py | 37 ++-- examples/rl_gsm8k/utils.py | 27 +-- tapeagents/agent.py | 12 +- tapeagents/batch.py | 4 +- tapeagents/core.py | 4 +- tapeagents/dialog_tape.py | 2 - tapeagents/finetune/checkpoints.py | 3 +- tapeagents/finetune/finetune.py | 3 +- tapeagents/finetune/logging_.py | 3 +- tapeagents/finetune/rl/__init__.py | 1 + tapeagents/nodes.py | 10 +- tapeagents/observe.py | 1 + tapeagents/orchestrator.py | 8 +- tapeagents/parallel_processing.py | 16 +- tapeagents/steps.py | 2 +- tests/add_sqlite_columns.py | 5 +- tests/test_llm_function.py | 5 +- 35 files changed, 400 insertions(+), 378 deletions(-) diff --git a/examples/continue_tapes.py b/examples/continue_tapes.py index d2b5b373..60425b0e 100644 --- a/examples/continue_tapes.py +++ b/examples/continue_tapes.py @@ -4,7 +4,7 @@ from tapeagents.batch import generate_tapes from tapeagents.dialog_tape import AssistantStep, DialogTape, SystemStep, UserStep from tapeagents.environment import EmptyEnvironment -from tapeagents.llms import TrainableLLM, LLM +from tapeagents.llms import LLM, TrainableLLM from .llama_agent import LLAMAChatBot from .llama_user import LLAMAUserModel diff --git a/examples/data_science/tape_browser.py b/examples/data_science/tape_browser.py index 3c49fe82..447ae017 100644 --- a/examples/data_science/tape_browser.py +++ b/examples/data_science/tape_browser.py @@ -1,6 +1,6 @@ import os -from pathlib import Path import sys +from pathlib import Path from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer from tapeagents.tape_browser import TapeBrowser diff --git a/examples/delegate.py b/examples/delegate.py index 758576ff..42bddd6e 100644 --- a/examples/delegate.py +++ b/examples/delegate.py @@ -4,7 +4,7 @@ from tapeagents.agent import Agent, AgentEvent from tapeagents.core import Action, Prompt, Tape, Thought -from tapeagents.llms import TrainableLLM, LLM, LLMStream +from tapeagents.llms import LLM, LLMStream, TrainableLLM logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/examples/form_filler/types.py b/examples/form_filler/types.py index 8ee134a5..1bbcc8b9 100644 --- a/examples/form_filler/types.py +++ b/examples/form_filler/types.py @@ -2,6 +2,5 @@ from pydantic import Field - FunctionName: TypeAlias = Annotated[str, Field(description="The name of a function.")] ParameterName: TypeAlias = Annotated[str, Field(description="The name of a function parameter.")] diff --git a/examples/form_filler/utils.py b/examples/form_filler/utils.py index 36e9bd91..b4b0d77f 100644 --- a/examples/form_filler/utils.py +++ b/examples/form_filler/utils.py @@ -1,15 +1,15 @@ - from typing import Any, Type + from . import steps def render_chat_template(messages: list[dict[str, Any]], context: dict[str, Any]) -> list[dict[str, Any]]: formatted_chat = [] for message in messages: - formatted_content = message['content'].format(**context) + formatted_content = message["content"].format(**context) formatted_chat.append( dict( - role=message['role'], + role=message["role"], content=formatted_content, ) ) @@ -31,7 +31,7 @@ def sanitize_json_completion(completion: str) -> str: clean_lines = [] for line in lines: line = line.replace("\\", "") # remove all backslashes - line = ' '.join(line.split()) # remove all extra spaces + line = " ".join(line.split()) # remove all extra spaces if line.startswith("```"): tiks_counter += 1 if tiks_counter == 1: @@ -39,7 +39,7 @@ def sanitize_json_completion(completion: str) -> str: elif tiks_counter == 2: break continue - elif line.startswith("[") or line.startswith("{"): # detected start of the json section + elif line.startswith("[") or line.startswith("{"): # detected start of the json section if not opened: opened = True clean_lines = [] diff --git a/examples/gaia_agent/eval.py b/examples/gaia_agent/eval.py index 8ae1996c..c03e80f1 100644 --- a/examples/gaia_agent/eval.py +++ b/examples/gaia_agent/eval.py @@ -103,11 +103,11 @@ def solve_task( retries: int = 3, max_loops: int = 50, ) -> Generator[GaiaTape, None, None]: - """Solve GAIA task. - + """Solve GAIA task. + This function is a generator that yields intermediate tapes during the solving process. The last tape will contain the agent's response. - + """ start_steps = env.task_to_observations(task) solved = None @@ -119,8 +119,8 @@ def solve_task( if partial_tape := (event.agent_tape or event.env_tape): tape = partial_tape tape.metadata = GaiaMetadata.model_validate( - tape.metadata.model_dump() | {"task": task, "level": level} - ) + tape.metadata.model_dump() | {"task": task, "level": level} + ) yield tape if n_search_repetitions(tape) >= 3: break diff --git a/examples/gaia_agent/scripts/debug.py b/examples/gaia_agent/scripts/debug.py index 64f5f7d3..1d59d8ca 100644 --- a/examples/gaia_agent/scripts/debug.py +++ b/examples/gaia_agent/scripts/debug.py @@ -43,9 +43,7 @@ def main(cfg: DictConfig) -> None: env = GaiaEnvironment(vision_lm=llm, code_sandbox=code_sandbox) agent = GaiaAgent.create(llm, **cfg.agent) tape = GaiaTape(steps=env.task_to_observations(task)) - tape.metadata = GaiaMetadata.model_validate( - tape.metadata.model_dump() | {"task": task, "level": cfg.level} - ) + tape.metadata = GaiaMetadata.model_validate(tape.metadata.model_dump() | {"task": task, "level": cfg.level}) step_count = 0 for event in main_loop(agent, tape, env, max_loops=50): if event.agent_event and event.agent_event.step: diff --git a/examples/gaia_agent/scripts/studio.py b/examples/gaia_agent/scripts/studio.py index dacf8a0a..5f5fcb73 100644 --- a/examples/gaia_agent/scripts/studio.py +++ b/examples/gaia_agent/scripts/studio.py @@ -36,12 +36,12 @@ def main(cfg: DictConfig) -> None: env = GaiaEnvironment(vision_lm=llm, attachment_dir=attachment_dir) agent = GaiaAgent.create(llm, **cfg.agent) content = "How many calories in 2 teaspoons of hummus" - if cfg.studio.tape: + if cfg.studio.tape: tape = load_tapes(GaiaTape, cfg.studio.tape, ".json")[0] else: # Uncomment the following line to test video question # content = "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?" - tape = GaiaTape(steps=[GaiaQuestion(content=content)]) + tape = GaiaTape(steps=[GaiaQuestion(content=content)]) Studio(agent, tape, CameraReadyRenderer(), env).launch(server_name="0.0.0.0", static_dir=attachment_dir) diff --git a/examples/gsm8k_tuning/math_agent.py b/examples/gsm8k_tuning/math_agent.py index c370fa15..80394752 100644 --- a/examples/gsm8k_tuning/math_agent.py +++ b/examples/gsm8k_tuning/math_agent.py @@ -7,8 +7,8 @@ from tapeagents.agent import Agent from tapeagents.core import ( Action, - LLMOutputParsingFailureAction, FinalStep, + LLMOutputParsingFailureAction, Observation, SetNextNode, Tape, diff --git a/examples/optimize/func_templates.py b/examples/optimize/func_templates.py index efaac217..df289dce 100644 --- a/examples/optimize/func_templates.py +++ b/examples/optimize/func_templates.py @@ -1,5 +1,5 @@ from tapeagents.dialog_tape import ToolResult -from tapeagents.llm_function import Input, LLMFunctionTemplate, AssistantOutput, RationaleOutput, ToolCallOutput +from tapeagents.llm_function import AssistantOutput, Input, LLMFunctionTemplate, RationaleOutput, ToolCallOutput def render_contexts(contexts: list[str]) -> str: diff --git a/examples/optimize/load_demos.py b/examples/optimize/load_demos.py index 0041d719..58039458 100644 --- a/examples/optimize/load_demos.py +++ b/examples/optimize/load_demos.py @@ -11,7 +11,6 @@ UserStep, ) - res_dir = pathlib.Path(__file__).parent.parent.resolve() / "res" diff --git a/examples/rl_gsm8k/browse.py b/examples/rl_gsm8k/browse.py index dfb49cfb..5bac1e90 100644 --- a/examples/rl_gsm8k/browse.py +++ b/examples/rl_gsm8k/browse.py @@ -1,10 +1,10 @@ import os -from pathlib import Path import sys +from pathlib import Path +from examples.rl_gsm8k.cot_math_agent import MathTape from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer from tapeagents.tape_browser import TapeBrowser -from examples.rl_gsm8k.cot_math_agent import MathTape # comment this code out if loading the prompt and completions takes too long for you tape_dir = Path(sys.argv[1]) diff --git a/examples/rl_gsm8k/deepseek_math_eval/answer_extraction.py b/examples/rl_gsm8k/deepseek_math_eval/answer_extraction.py index c2691d9d..598eec78 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/answer_extraction.py +++ b/examples/rl_gsm8k/deepseek_math_eval/answer_extraction.py @@ -1,6 +1,8 @@ import re + import regex + def _fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] @@ -129,7 +131,7 @@ def strip_string(string): string = string.replace("inf", "\\infty") string = string.replace("+\\inity", "\\infty") - # and + # and # string = string.replace("and", "") string = string.replace("\\mathbf", "") string = string.replace("\\mathrm", "") @@ -139,8 +141,8 @@ def strip_string(string): # quote string.replace("'", "") - string.replace("\"", "") - + string.replace('"', "") + # i, j if "j" in string and "i" not in string: string = string.replace("j", "i") @@ -174,57 +176,60 @@ def strip_string(string): return string + def extract_boxed_answers(text): answers = [] - for piece in text.split('boxed{')[1:]: + for piece in text.split("boxed{")[1:]: n = 0 for i in range(len(piece)): - if piece[i] == '{': + if piece[i] == "{": n += 1 - elif piece[i] == '}': + elif piece[i] == "}": n -= 1 if n < 0: - if i + 1 < len(piece) and piece[i + 1] == '%': + if i + 1 < len(piece) and piece[i + 1] == "%": answers.append(piece[: i + 1]) else: answers.append(piece[:i]) break return answers + def extract_program_output(pred_str): """ extract output between the last ```output\n...\n``` """ if "```output" not in pred_str: return "" - if '```output' in pred_str: - pred_str = pred_str.split('```output')[-1] - if '```' in pred_str: - pred_str = pred_str.split('```')[0] + if "```output" in pred_str: + pred_str = pred_str.split("```output")[-1] + if "```" in pred_str: + pred_str = pred_str.split("```")[0] output = pred_str.strip() return output + def extract_answer(pred_str, exhaust=False): pred = [] - if 'final answer is $' in pred_str and '$. I hope' in pred_str: - tmp = pred_str.split('final answer is $', 1)[1] - pred = [tmp.split('$. I hope', 1)[0].strip()] - elif 'boxed' in pred_str: + if "final answer is $" in pred_str and "$. I hope" in pred_str: + tmp = pred_str.split("final answer is $", 1)[1] + pred = [tmp.split("$. I hope", 1)[0].strip()] + elif "boxed" in pred_str: pred = extract_boxed_answers(pred_str) - elif ('he answer is' in pred_str): - pred = [pred_str.split('he answer is')[-1].strip()] + elif "he answer is" in pred_str: + pred = [pred_str.split("he answer is")[-1].strip()] else: program_output = extract_program_output(pred_str) if program_output != "": # fall back to program pred.append(program_output) - else: # use the last number - pattern = '-?\d*\.?\d+' + else: # use the last number + pattern = "-?\d*\.?\d+" ans = re.findall(pattern, pred_str.replace(",", "")) - if(len(ans) >= 1): + if len(ans) >= 1: ans = ans[-1] else: - ans = '' + ans = "" if ans: pred.append(ans) @@ -242,10 +247,11 @@ def extract_answer(pred_str, exhaust=False): else: return _pred[-1] if _pred else "" + def extract_math_answer(question, reasoning, task): answer = [] for ans in extract_answer(reasoning, exhaust=True): - if 'separated by commas' in question and all(ch not in ans for ch in '()[]'): + if "separated by commas" in question and all(ch not in ans for ch in "()[]"): answer.extend([a.strip() for a in ans.split(",")]) elif regex.search(r"\\text\{\s*and\s*\}", ans): answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")]) @@ -253,83 +259,93 @@ def extract_math_answer(question, reasoning, task): answer.append(ans.strip()) return answer + def extract_math_few_shot_cot_answer(question, reasoning, task): - if 'Problem:' in reasoning: + if "Problem:" in reasoning: reasoning = reasoning.split("Problem:", 1)[0] return extract_math_answer(question, reasoning, task) + def extract_last_single_answer(question, reasoning, task): return extract_answer(reasoning, exhaust=False) + def extract_gsm_few_shot_cot_answer(question, reasoning, task): - if 'Q: ' in reasoning: + if "Q: " in reasoning: reasoning = reasoning.split("Q: ", 1)[0] - pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)] + pred = [s for s in regex.findall(r"-?\d+\.?\d*", reasoning)] if pred: return pred[-1] else: return "[invalid]" + def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task): - if '问题 ' in reasoning: + if "问题 " in reasoning: reasoning = reasoning.split("问题 ", 1)[0] - if '答案是' in reasoning: - ans = reasoning.split('答案是', 1)[1].strip() + if "答案是" in reasoning: + ans = reasoning.split("答案是", 1)[1].strip() ans = ans.split("\n")[0].strip() ans = [ans.strip("$")] else: - ans = ['placeholder'] + ans = ["placeholder"] return ans + def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task): - if '问题 ' in reasoning: + if "问题 " in reasoning: reasoning = reasoning.split("问题 ", 1)[0] - if '答案是' in reasoning: - ans = reasoning.split('答案是', 1)[1].strip() + if "答案是" in reasoning: + ans = reasoning.split("答案是", 1)[1].strip() ans = ans.split("\n")[0].strip() else: - ans = 'placeholder' + ans = "placeholder" return ans + def extract_sat_few_shot_answer(question, reasoning, task): - if 'Problem:' in reasoning: + if "Problem:" in reasoning: reasoning = reasoning.split("Problem:", 1)[0] patt = regex.search(r"the final answer is \(?(?P[abcd])\)?", reasoning.lower()) if patt is not None: - return patt.group('ans').upper() - return 'placeholder' + return patt.group("ans").upper() + return "placeholder" + def extract_ocwcourses_few_shot_answer(question, reasoning, task): - if 'Problem:' in reasoning: + if "Problem:" in reasoning: reasoning = reasoning.split("Problem:", 1)[0] patt = regex.search(r"final answer is (?P.*)\. I hope it is correct.", reasoning) if patt is None: pred = "[invalid]" print(f"DEBUG >>>\n{reasoning}", flush=True) else: - pred = patt.group('ans') + pred = patt.group("ans") return pred + def extract_mmlu_stem(question, reasoning, task): - if 'Problem:' in reasoning: + if "Problem:" in reasoning: reasoning = reasoning.split("Problem:", 1)[0] return extract_sat_few_shot_answer(question, reasoning, task) + def extract_minif2f_isabelle(question, reasoning, task): - if 'Informal:' in reasoning: + if "Informal:" in reasoning: reasoning = reasoning.split("Informal:", 1)[0] return reasoning.strip() + def extract_cmath_few_shot_test(question, reasoning, task): - if '问题:' in reasoning: + if "问题:" in reasoning: reasoning = reasoning.split("问题:", 1)[0] - if '答案是' in reasoning: - ans = reasoning.split('答案是', 1)[1].strip() + if "答案是" in reasoning: + ans = reasoning.split("答案是", 1)[1].strip() ans = ans.split("\n")[0] ans = ans.strip(":") ans = ans.strip("。") try: - ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1] + ans = [s for s in regex.findall(r"-?\d+\.?\d*", ans)][-1] except: print(f"DEBUG CMATH: {reasoning}", flush=True) ans = "[invalid]" diff --git a/examples/rl_gsm8k/deepseek_math_eval/eval_script.py b/examples/rl_gsm8k/deepseek_math_eval/eval_script.py index 501c173d..fe160e9a 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/eval_script.py +++ b/examples/rl_gsm8k/deepseek_math_eval/eval_script.py @@ -1,40 +1,47 @@ -import regex from copy import deepcopy + +import regex + from examples.rl_gsm8k.deepseek_math_eval.eval_utils import math_equal -from examples.rl_gsm8k.deepseek_math_eval.ocw_courses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin +from examples.rl_gsm8k.deepseek_math_eval.ocw_courses_eval_utils import ( + SymbolicMathMixin, + normalize_numeric, + normalize_symbolic_equation, + numeric_equality, +) -def is_correct(item, pred_key='prediction', prec=1e-3): + +def is_correct(item, pred_key="prediction", prec=1e-3): pred = item[pred_key] - ans = item['answer'] + ans = item["answer"] if isinstance(pred, list) and isinstance(ans, list): pred_matched = set() ans_matched = set() for i in range(len(pred)): for j in range(len(ans)): item_cpy = deepcopy(item) - item_cpy.update({ - pred_key: pred[i], - 'answer': ans[j] - }) + item_cpy.update({pred_key: pred[i], "answer": ans[j]}) if is_correct(item_cpy, pred_key=pred_key, prec=prec): pred_matched.add(i) ans_matched.add(j) - if item_cpy[pred_key] == '2,3,4': + if item_cpy[pred_key] == "2,3,4": print(item, flush=True) print("wtf", flush=True) return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) elif isinstance(pred, str) and isinstance(ans, str): - if '\\cup' in pred and '\\cup' in ans: + if "\\cup" in pred and "\\cup" in ans: item = deepcopy(item) - item.update({ - pred_key: pred.split('\\cup'), - 'answer': ans.split('\\cup'), - }) + item.update( + { + pred_key: pred.split("\\cup"), + "answer": ans.split("\\cup"), + } + ) return is_correct(item, pred_key=pred_key, prec=prec) else: label = False try: - label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec + label = abs(float(regex.sub(r",", "", str(pred))) - float(regex.sub(r",", "", str(ans)))) < prec except: pass label = label or (ans and pred == ans) or math_equal(pred, ans) @@ -43,11 +50,12 @@ def is_correct(item, pred_key='prediction', prec=1e-3): print(item, flush=True) raise NotImplementedError() -def eval_math(item, pred_key='prediction', prec=1e-3): + +def eval_math(item, pred_key="prediction", prec=1e-3): pred = item[pred_key] - if pred_key == 'program_output' and isinstance(pred, str): + if pred_key == "program_output" and isinstance(pred, str): pred = [pred] - ans = item['answer'] + ans = item["answer"] if isinstance(pred, list) and isinstance(ans, list): # for some questions in MATH, `reference` repeats answers _ans = [] @@ -61,81 +69,86 @@ def eval_math(item, pred_key='prediction', prec=1e-3): if a not in _pred: _pred.append(a) # some predictions mistakenly box non-answer strings - pred = _pred[-len(ans):] + pred = _pred[-len(ans) :] - item.update({ - pred_key: pred, - 'answer': ans - }) + item.update({pred_key: pred, "answer": ans}) return is_correct(item, pred_key=pred_key, prec=prec) -def eval_last_single_answer(item, pred_key='prediction', prec=1e-3): - for key in [pred_key, 'answer']: + +def eval_last_single_answer(item, pred_key="prediction", prec=1e-3): + for key in [pred_key, "answer"]: assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" return is_correct(item, pred_key=pred_key, prec=prec) -def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3): - if pred_key == 'program_output' and isinstance(item[pred_key], str): + +def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3): + if pred_key == "program_output" and isinstance(item[pred_key], str): item[pred_key] = [item[pred_key]] - for key in [pred_key, 'answer']: + for key in [pred_key, "answer"]: assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list" pred = item[pred_key] - ans = item['answer'] + ans = item["answer"] _pred = [] for p in pred: p = p + ";" while p: left_brackets = 0 for i in range(len(p)): - if p[i] == ';' or (p[i] == ',' and left_brackets == 0): - _p, p = p[:i].strip(), p[i + 1:].strip() + if p[i] == ";" or (p[i] == "," and left_brackets == 0): + _p, p = p[:i].strip(), p[i + 1 :].strip() if _p not in _pred: _pred.append(_p) break - elif p[i] in '([{': + elif p[i] in "([{": left_brackets += 1 - elif p[i] in ')]}': + elif p[i] in ")]}": left_brackets -= 1 - pred = _pred[-len(ans):] + pred = _pred[-len(ans) :] if len(pred) == len(ans): for p, a in zip(pred, ans): - item.update({ - pred_key: p, - 'answer': a, - }) + item.update( + { + pred_key: p, + "answer": a, + } + ) if not is_correct(item, pred_key=pred_key, prec=prec): return False return True else: return False -def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3): - if pred_key == 'program_output' and isinstance(item[pred_key], str): + +def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3): + if pred_key == "program_output" and isinstance(item[pred_key], str): item[pred_key] = [item[pred_key]] pred_str = " ".join(item[pred_key]) - ans = item['answer'] + ans = item["answer"] tag = None idx = -1 - for t in 'ABCD': + for t in "ABCD": if t in pred_str and pred_str.index(t) > idx: tag = t idx = pred_str.index(t) return tag == ans -def eval_math_sat(item, pred_key='prediction', prec=1e-3): - for key in [pred_key, 'answer']: + +def eval_math_sat(item, pred_key="prediction", prec=1e-3): + for key in [pred_key, "answer"]: assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" - return item[pred_key].lower() == item['answer'].lower() + return item[pred_key].lower() == item["answer"].lower() -def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3): + +def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3): return eval_math_sat(item, pred_key=pred_key, prec=prec) -def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): + +def eval_ocwcourses(item, pred_key="prediction", prec=1e-3): INVALID_ANSWER = "[invalidanswer]" - for key in [pred_key, 'answer']: + for key in [pred_key, "answer"]: assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" pred = item[pred_key] - ans = item['answer'] + ans = item["answer"] try: float(ans) @@ -145,7 +158,7 @@ def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): except ValueError: if "=" in ans: normalize_fn = normalize_symbolic_equation - is_equiv = lambda x, y: x==y + is_equiv = lambda x, y: x == y answer_type = "equation" else: normalize_fn = SymbolicMathMixin().normalize_tex @@ -168,5 +181,6 @@ def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): return acc -def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3): + +def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3): return True diff --git a/examples/rl_gsm8k/deepseek_math_eval/eval_utils.py b/examples/rl_gsm8k/deepseek_math_eval/eval_utils.py index f0e1a1c8..0d3f7de7 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/eval_utils.py +++ b/examples/rl_gsm8k/deepseek_math_eval/eval_utils.py @@ -1,16 +1,17 @@ import multiprocessing +import re from math import isclose -import numpy as np -from typing import Union, Any, Dict +from typing import Any, Dict, Union -from sympy import simplify, N -from sympy.parsing.sympy_parser import parse_expr -from sympy.parsing.latex import parse_latex -import re +import numpy as np import regex +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr from examples.rl_gsm8k.deepseek_math_eval.answer_extraction import extract_answer, extract_program_output, strip_string + def extract_program(result: str, last_only=True): """ extract the program after "```python", and before "```" @@ -20,7 +21,7 @@ def extract_program(result: str, last_only=True): for line in result.split("\n"): if line.startswith("```python"): if last_only: - program = "" # only extract the last program + program = "" # only extract the last program else: program += "\n# ========\n" start = True @@ -32,38 +33,38 @@ def extract_program(result: str, last_only=True): def parse_ground_truth(example: Dict[str, Any], data_name): - if 'gt_cot' in example: - return example['gt_cot'], strip_string(example['gt']) + if "gt_cot" in example: + return example["gt_cot"], strip_string(example["gt"]) # parse ground truth - if data_name in ["math", 'ocw']: - gt_cot = example['solution'] + if data_name in ["math", "ocw"]: + gt_cot = example["solution"] gt_ans = extract_answer(gt_cot) elif data_name == "gsm8k": - gt_cot, gt_ans = example['answer'].split("####") + gt_cot, gt_ans = example["answer"].split("####") elif data_name == "gsm-hard": - gt_cot, gt_ans = example['code'], example['target'] + gt_cot, gt_ans = example["code"], example["target"] elif data_name == "svamp": - gt_cot, gt_ans = example['Equation'], example['Answer'] + gt_cot, gt_ans = example["Equation"], example["Answer"] elif data_name == "asdiv": - gt_cot = example['formula'] - gt_ans = re.sub(r"\(.*?\)", "", example['answer']) + gt_cot = example["formula"] + gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) elif data_name == "mawps": - gt_cot, gt_ans = None, example['target'] + gt_cot, gt_ans = None, example["target"] elif data_name == "tabmwp": - gt_cot = example['solution'] - gt_ans = example['answer'] - if example['ans_type'] in ['integer_number', 'decimal_number']: - if '/' in gt_ans: - gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) - elif ',' in gt_ans: - gt_ans = float(gt_ans.replace(',', '')) - elif '%' in gt_ans: - gt_ans = float(gt_ans.split('%')[0]) / 100 + gt_cot = example["solution"] + gt_ans = example["answer"] + if example["ans_type"] in ["integer_number", "decimal_number"]: + if "/" in gt_ans: + gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) + elif "," in gt_ans: + gt_ans = float(gt_ans.replace(",", "")) + elif "%" in gt_ans: + gt_ans = float(gt_ans.split("%")[0]) / 100 else: gt_ans = float(gt_ans) elif data_name == "bbh": - gt_cot, gt_ans = None, example['target'] + gt_cot, gt_ans = None, example["target"] else: raise NotImplementedError(data_name) # post process @@ -82,13 +83,13 @@ def parse_question(example, data_name): body = body + "." question = f'{body} {example["Question"].strip()}' elif data_name == "tabmwp": - title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" - question = f'Read the following table {title_str}and answer a question:\n' + title_str = f'regarding "{example["table_title"]}" ' if example["table_title"] else "" + question = f"Read the following table {title_str}and answer a question:\n" question += f'{example["table"]}\n{example["question"]}' - if example['choices']: + if example["choices"]: question += f' Please select from the following options: {example["choices"]}' else: - for key in ['question', 'problem', 'Question', 'input']: + for key in ["question", "problem", "Question", "input"]: if key in example: question = example[key] break @@ -97,7 +98,7 @@ def parse_question(example, data_name): def run_execute(executor, result, prompt_type, execute=False): - if not result or result == 'error': + if not result or result == "error": return None, None report = None @@ -115,13 +116,13 @@ def run_execute(executor, result, prompt_type, execute=False): def parse_digits(num): # format: 234.23 || 23% - num = regex.sub(',', '', str(num)) + num = regex.sub(",", "", str(num)) try: return float(num) except: - if num.endswith('%'): + if num.endswith("%"): num = num[:-1] - if num.endswith('\\'): + if num.endswith("\\"): num = num[:-1] try: return float(num) / 100 @@ -129,13 +130,14 @@ def parse_digits(num): pass return None + def is_digit(num): # paired with parse_digits return parse_digits(num) is not None def normalize_prediction(prediction): - try: # 1. numerical equal + try: # 1. numerical equal if is_digit(prediction): prediction = np.round(float(str(prediction).replace(",", "")), 6) return str(prediction) @@ -147,20 +149,24 @@ def normalize_prediction(prediction): ## deal with [], (), {} brackets = [] - while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")): + while ( + prediction.startswith("[") + and prediction.endswith("]") + or (prediction.startswith("(") and prediction.endswith(")")) + ): bracket = prediction[0] prediction = prediction[1:-1] - if brackets and ',' in prediction: + if brackets and "," in prediction: pred_parts = [normalize_prediction(part) for part in prediction.split(",")] prediction = ",".join(pred_parts) if brackets: for b in reversed(brackets): - if b == '[': - prediction = '[' + prediction + ']' + if b == "[": + prediction = "[" + prediction + "]" else: - assert b == '(' - prediction = '(' + prediction + ')' + assert b == "(" + prediction = "(" + prediction + ")" def _parse(s): for f in [parse_latex, parse_expr]: @@ -172,18 +178,19 @@ def _parse(s): prediction = _parse(prediction) - for s in ['{', "}", "(", ")"]: + for s in ["{", "}", "(", ")"]: prediction = prediction.replace(s, "") return prediction -def math_equal(prediction: Union[bool, float, str], - reference: Union[float, str], - include_percentage: bool = True, - is_close: bool = True, - timeout: bool = False, - ) -> bool: +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, + timeout: bool = False, +) -> bool: """ Exact match of math if and only if: 1. numerical equal: both can convert to float and are equal @@ -192,7 +199,7 @@ def math_equal(prediction: Union[bool, float, str], if str(prediction) == str(reference): return True - try: # 1. numerical equal + try: # 1. numerical equal if is_digit(prediction) and is_digit(reference): prediction = parse_digits(prediction) reference = parse_digits(reference) @@ -222,24 +229,46 @@ def math_equal(prediction: Union[bool, float, str], reference = str(reference).strip() prediction = str(prediction).strip() - if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: + if ( + regex.match(r"(\(|\[).+(\)|\])", prediction) is not None + and regex.match(r"(\(|\[).+(\)|\])", reference) is not None + ): pred_parts = prediction[1:-1].split(",") ref_parts = reference[1:-1].split(",") if len(pred_parts) == len(ref_parts): - if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): + if all( + [math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))] + ): return True - if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ - (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): - pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] - ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] + if ( + (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) + and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) + and (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) + and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")) + ): + pred_lines = [ + line.strip() + for line in prediction[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\") + if line.strip() + ] + ref_lines = [ + line.strip() + for line in reference[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\") + if line.strip() + ] matched = True if len(pred_lines) == len(ref_lines): for pred_line, ref_line in zip(pred_lines, ref_lines): pred_parts = pred_line.split("&") ref_parts = ref_line.split("&") if len(pred_parts) == len(ref_parts): - if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): + if not all( + [ + math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) + for i in range(len(pred_parts)) + ] + ): matched = False break else: @@ -251,18 +280,18 @@ def math_equal(prediction: Union[bool, float, str], if matched: return True - if prediction.count('=') == 1 and reference.count('=') == 1: - pred = prediction.split('=') + if prediction.count("=") == 1 and reference.count("=") == 1: + pred = prediction.split("=") pred = f"{pred[0].strip()} - ({pred[1].strip()})" - ref = reference.split('=') + ref = reference.split("=") ref = f"{ref[0].strip()} - ({ref[1].strip()})" if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): return True - elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: - if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): + elif prediction.count("=") == 1 and len(prediction.split("=")[0].strip()) <= 2 and "=" not in reference: + if math_equal(prediction.split("=")[1], reference, include_percentage, is_close): return True - elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: - if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): + elif reference.count("=") == 1 and len(reference.split("=")[0].strip()) <= 2 and "=" not in prediction: + if math_equal(prediction, reference.split("=")[1], include_percentage, is_close): return True # symbolic equal with sympy @@ -288,11 +317,12 @@ def _parse(s): except: pass return s + a = _parse(a) b = _parse(b) try: - if simplify(a-b) == 0: + if simplify(a - b) == 0: return True except: pass @@ -305,21 +335,21 @@ def _parse(s): return False -def symbolic_equal_process(a, b, output_queue): +def symbolic_equal_process(a, b, output_queue): result = symbolic_equal(a, b) - output_queue.put(result) + output_queue.put(result) -def call_with_timeout(func, *args, timeout=1, **kwargs): - output_queue = multiprocessing.Queue() - process_args = args + (output_queue,) - process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) - process.start() - process.join(timeout) - - if process.is_alive(): +def call_with_timeout(func, *args, timeout=1, **kwargs): + output_queue = multiprocessing.Queue() + process_args = args + (output_queue,) + process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) + process.start() + process.join(timeout) + + if process.is_alive(): process.terminate() - process.join() - return False - + process.join() + return False + return output_queue.get() diff --git a/examples/rl_gsm8k/deepseek_math_eval/ocw_courses_eval_utils.py b/examples/rl_gsm8k/deepseek_math_eval/ocw_courses_eval_utils.py index 4e77598f..cd32795f 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/ocw_courses_eval_utils.py +++ b/examples/rl_gsm8k/deepseek_math_eval/ocw_courses_eval_utils.py @@ -1,13 +1,14 @@ import re +import signal + import numpy as np import sympy from sympy.core.sympify import SympifyError from sympy.parsing.latex import parse_latex -import signal - INVALID_ANSWER = "[invalidanswer]" + class timeout: def __init__(self, seconds=1, error_message="Timeout"): self.seconds = seconds @@ -23,6 +24,7 @@ def __enter__(self): def __exit__(self, type, value, traceback): signal.alarm(0) + def normalize_numeric(s): if s is None: return None @@ -66,6 +68,7 @@ def normalize_numeric(s): except: return INVALID_ANSWER + def numeric_equality(n1, n2, threshold=0.01): if n1 is None or n2 is None: return False @@ -74,6 +77,7 @@ def numeric_equality(n1, n2, threshold=0.01): else: return np.isclose(n1, n2) + def normalize_symbolic_equation(s): if not isinstance(s, str): return INVALID_ANSWER @@ -96,6 +100,7 @@ def normalize_symbolic_equation(s): except: return INVALID_ANSWER + class SymbolicMathMixin: """ Methods useful for parsing mathematical expressions from text and determining equivalence of expressions. @@ -205,7 +210,7 @@ def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic: with timeout(seconds=time_limit): parsed = parse_latex(text) except ( - # general error handling: there is a long tail of possible sympy/other + # general error handling: there is a long tail of possible sympy/other # errors we would like to catch Exception ) as e: @@ -223,9 +228,7 @@ def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: try: diff = x1 - x2 except (SympifyError, ValueError, TypeError) as e: - print( - f"Couldn't subtract {x1} and {x2} with exception {e}" - ) + print(f"Couldn't subtract {x1} and {x2} with exception {e}") return False try: @@ -236,7 +239,7 @@ def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: except (SympifyError, ValueError, TypeError) as e: print(f"Failed to simplify {x1}-{x2} with {e}") return False - except TimeoutError as e: + except TimeoutError: print(f"Timed out comparing {x1} and {x2}") return False except Exception as e: @@ -251,13 +254,13 @@ def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool: following the (Lewkowycz et al. 2022) methodology. """ if x1 == x2: - # don't resort to sympy if we have full string match, post-normalization + # don't resort to sympy if we have full string match, post-normalization return True - else: + else: return False parsed_x2 = self.parse_tex(x2) if not parsed_x2: - # if our reference fails to parse into a Sympy object, + # if our reference fails to parse into a Sympy object, # we forgo parsing + checking our generated answer. return False return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit) diff --git a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py index 29f9d25f..e72af3c0 100644 --- a/examples/rl_gsm8k/deepseek_math_eval/process_utils.py +++ b/examples/rl_gsm8k/deepseek_math_eval/process_utils.py @@ -4,160 +4,128 @@ from examples.rl_gsm8k.deepseek_math_eval.answer_extraction import extract_math_answer, strip_string from examples.rl_gsm8k.deepseek_math_eval.eval_utils import parse_ground_truth + def process_gsm8k_test(item): _, answer = parse_ground_truth(item, "gsm8k") - sample = { - 'dataset': 'gsm8k-cot', - 'task': item['question'], - 'answer': answer - } + sample = {"dataset": "gsm8k-cot", "task": item["question"], "answer": answer} return sample + def process_math_test(item): question = item["problem"] try: - answer = extract_math_answer(question, item['solution'], task="cot") - except Exception as e: + answer = extract_math_answer(question, item["solution"], task="cot") + except Exception: return - sample = { - "dataset": "math-cot", - "level": item["level"], - "type": item["type"], - "task": question, - "answer": answer - } + sample = {"dataset": "math-cot", "level": item["level"], "type": item["type"], "task": question, "answer": answer} return sample + def process_math_sat(item): - options = item['options'].strip() - assert 'A' == options[0] - options = '(' + options - for ch in 'BCDEFG': - if f' {ch}) ' in options: - options = regex.sub(f' {ch}\) ', f" ({ch}) ", options) + options = item["options"].strip() + assert "A" == options[0] + options = "(" + options + for ch in "BCDEFG": + if f" {ch}) " in options: + options = regex.sub(f" {ch}\) ", f" ({ch}) ", options) question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}" - messages = [ - {'role': 'user', 'content': question}, - {'role': 'assistant', 'content': item['Answer']} - ] + messages = [{"role": "user", "content": question}, {"role": "assistant", "content": item["Answer"]}] item = { - 'dataset': 'math_sat', - 'id': item['id'], - 'language': 'en', - 'messages': messages, - 'answer': item['Answer'], + "dataset": "math_sat", + "id": item["id"], + "language": "en", + "messages": messages, + "answer": item["Answer"], } yield item + def process_ocwcourses(item): messages = [ - {'role': 'user', 'content': item['problem'].strip()}, - {'role': 'assistant', 'content': item['solution'].strip()} + {"role": "user", "content": item["problem"].strip()}, + {"role": "assistant", "content": item["solution"].strip()}, ] - item = { - "dataset": "OCWCourses", - "id": item['id'], - "language": "en", - "messages": messages, - "answer": item['answer'] - } + item = {"dataset": "OCWCourses", "id": item["id"], "language": "en", "messages": messages, "answer": item["answer"]} yield item + def process_mmlu_stem(item): - options = item['options'] - for i, (label, option) in enumerate(zip('ABCD', options)): + options = item["options"] + for i, (label, option) in enumerate(zip("ABCD", options)): options[i] = f"({label}) {str(option).strip()}" options = ", ".join(options) question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}" - messages = [ - {'role': 'user', 'content': question}, - {'role': 'assistant', 'content': item['answer']} - ] - item = { - "dataset": "MMLU-STEM", - "id": item['id'], - "language": "en", - "messages": messages, - "answer": item['answer'] - } + messages = [{"role": "user", "content": question}, {"role": "assistant", "content": item["answer"]}] + item = {"dataset": "MMLU-STEM", "id": item["id"], "language": "en", "messages": messages, "answer": item["answer"]} yield item + def process_mgsm_zh(item): - item['answer'] = item['answer'].replace(',', '') + item["answer"] = item["answer"].replace(",", "") yield item + def process_cmath(item): item = { - 'dataset': 'cmath', - 'id': item['id'], - 'grade': item['grade'], - 'reasoning_step': item['reasoning_step'], - 'messages': [ - {'role': 'user', 'content': item['question'].strip()}, - {'role': 'assistant', 'content': ''} - ], - 'answer': item['golden'].strip().replace(",", "") + "dataset": "cmath", + "id": item["id"], + "grade": item["grade"], + "reasoning_step": item["reasoning_step"], + "messages": [{"role": "user", "content": item["question"].strip()}, {"role": "assistant", "content": ""}], + "answer": item["golden"].strip().replace(",", ""), } yield item + def process_agieval_gaokao_math_cloze(item): item = { - 'dataset': 'agieval-gaokao-math-cloze', - 'id': item['id'], - 'messages': [ - {'role': 'user', 'content': item['question'].strip()}, - {'role': 'assistant', 'content': ''} - ], - 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")] + "dataset": "agieval-gaokao-math-cloze", + "id": item["id"], + "messages": [{"role": "user", "content": item["question"].strip()}, {"role": "assistant", "content": ""}], + "answer": [strip_string(ans) for ans in item["answer"].strip().split(";")], } yield item + def process_agieval_gaokao_mathqa(item): - question = item['question'].strip() + question = item["question"].strip() options = [] - for option in item['options']: + for option in item["options"]: option = option.strip() - assert option[0] == '(' - assert option[2] == ')' - assert option[1] in 'ABCD' + assert option[0] == "(" + assert option[2] == ")" + assert option[1] in "ABCD" option = f"{option[1]}: {option[3:].strip()}" options.append(option.strip()) question = f"{question}\n{options}" item = { - 'dataset': 'agieval-gaokao-mathqa', - 'id': item['id'], - 'messages': [ - {'role': 'user', 'content': question}, - {'role': 'assistant', 'content': ''} - ], - "answer": item['label'] + "dataset": "agieval-gaokao-mathqa", + "id": item["id"], + "messages": [{"role": "user", "content": question}, {"role": "assistant", "content": ""}], + "answer": item["label"], } yield item + def process_agieval_gaokao_mathqa_few_shot_cot_test(item): - question = item['question'].strip().rstrip('\\') - options = " ".join([opt.strip() for opt in item['options']]) + question = item["question"].strip().rstrip("\\") + options = " ".join([opt.strip() for opt in item["options"]]) question = f"{question}\n从以下选项中选择: {options}" item = { - 'dataset': 'agieval-gaokao-mathqa', - 'id': item['id'], - 'messages': [ - {'role': 'user', 'content': question}, - {'role': 'assistant', 'content': ''} - ], - "answer": item['label'] + "dataset": "agieval-gaokao-mathqa", + "id": item["id"], + "messages": [{"role": "user", "content": question}, {"role": "assistant", "content": ""}], + "answer": item["label"], } yield item + def process_minif2f_isabelle(item): question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}" item = { - 'dataset': 'minif2f-isabelle', - 'id': item['id'], - 'messages': [ - {'role': 'user', 'content': question}, - {'role': 'assistant', 'content': ''} - ], - "answer": "placeholder" + "dataset": "minif2f-isabelle", + "id": item["id"], + "messages": [{"role": "user", "content": question}, {"role": "assistant", "content": ""}], + "answer": "placeholder", } yield item diff --git a/examples/rl_gsm8k/gather_jsons.py b/examples/rl_gsm8k/gather_jsons.py index ccdb49bc..22ebda7b 100644 --- a/examples/rl_gsm8k/gather_jsons.py +++ b/examples/rl_gsm8k/gather_jsons.py @@ -1,8 +1,8 @@ # read json files from a folder, create new json with the same name that contains all the content -import sys import json import os +import sys def gather_jsons(folder: str): diff --git a/examples/rl_gsm8k/orchestrate_rl.py b/examples/rl_gsm8k/orchestrate_rl.py index 034552f1..581dfb6d 100644 --- a/examples/rl_gsm8k/orchestrate_rl.py +++ b/examples/rl_gsm8k/orchestrate_rl.py @@ -14,41 +14,26 @@ import hydra import numpy as np import torch +import wandb from datasets import load_dataset from omegaconf import DictConfig, OmegaConf from termcolor import colored from tqdm import tqdm -import wandb - -wandb.require("core") -from .cot_math_agent import ( - CoTMathAgent, - MathEnvironment, - RLMathTape, - Task, -) -from .deepseek_math_eval.answer_extraction import extract_last_single_answer, extract_math_answer -from .deepseek_math_eval.eval_script import eval_last_single_answer, eval_math -from .deepseek_math_eval.process_utils import process_gsm8k_test, process_math_test -from .utils import ( - VLLMServiceManager, - calculate_stats, - clean_up, - get_tokens_from_hf_tokenizer, - launch_training, - load_state, - save_state, - setup_logging, -) -from tapeagents.batch import batch_main_loop from tapeagents.core import LLMOutputParsingFailureAction, StepMetadata, TrainingText -from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb from tapeagents.finetune.data import MASKED_TOKEN_ID +from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb from tapeagents.llms import TrainableLLM -from tapeagents.observe import LLMCall, SQLiteWriterThread, retrieve_all_llm_calls from tapeagents.orchestrator import main_loop +from .cot_math_agent import CoTMathAgent, MathEnvironment, RLMathTape, Task +from .deepseek_math_eval.answer_extraction import extract_last_single_answer, extract_math_answer +from .deepseek_math_eval.eval_script import eval_last_single_answer, eval_math +from .deepseek_math_eval.process_utils import process_gsm8k_test, process_math_test +from .utils import VLLMServiceManager, calculate_stats, clean_up, launch_training, load_state, save_state, setup_logging + +wandb.require("core") + logger = logging.getLogger(__name__) @@ -469,7 +454,7 @@ def main(cfg: DictConfig): "execution_time/starting_assistantmodel_vllm": assistant_vllm_stats["starting_time"], "execution_time/starting_refmodel_vllm": refmodel_starting_time, } - logger.info(f"Logprob population stats:") + logger.info("Logprob population stats:") for stat_name, stat_value in logprob_stats.items(): logger.info(f"{stat_name}: {stat_value}") wandb.log(logprob_stats, step=state["iteration"]) diff --git a/examples/rl_gsm8k/utils.py b/examples/rl_gsm8k/utils.py index add63f84..33ac637a 100644 --- a/examples/rl_gsm8k/utils.py +++ b/examples/rl_gsm8k/utils.py @@ -1,27 +1,25 @@ import json import logging -import multiprocessing import os import shutil import subprocess +import threading import time from pathlib import Path -from typing import Dict, Optional, TextIO, Union, List -import threading +from typing import Dict, List, Optional, TextIO, Union + import numpy as np import psutil import requests import torch -import yaml -from omegaconf import DictConfig, ListConfig, OmegaConf from tenacity import retry, stop_after_attempt, wait_exponential -from examples.rl_gsm8k.run_finetune import run_finetuning_loop from transformers import PreTrainedTokenizer from tapeagents.llms import LLMOutput, Prompt logger = logging.getLogger(__name__) + def generate_cuda_device_strings(total_gpus: int, gpus_per_model: int) -> List[str]: """ Generate a list of CUDA device strings for assigning GPUs to models. @@ -40,6 +38,7 @@ def generate_cuda_device_strings(total_gpus: int, gpus_per_model: int) -> List[s cuda_device_strings.append(cuda_devices) return cuda_device_strings + class VLLMServiceManager: def __init__( self, @@ -70,9 +69,7 @@ def __init__( self.stats = {} def get_base_urls(self) -> list[str]: - return [ - f"http://127.0.0.1:{port}" for port in self.ports - ] + return [f"http://127.0.0.1:{port}" for port in self.ports] def _terminate_with_children(self, process_id: int) -> None: try: @@ -133,7 +130,9 @@ def _start_service(self) -> None: threads = [] - for i, device_number in enumerate(generate_cuda_device_strings(torch.cuda.device_count(), self.gpus_per_model_instance )): + for i, device_number in enumerate( + generate_cuda_device_strings(torch.cuda.device_count(), self.gpus_per_model_instance) + ): port = self.port + i # start_llm(device_number, port, assistant_procs, ports) thread = threading.Thread(target=self._start_llm, args=(device_number, port)) @@ -143,7 +142,6 @@ def _start_service(self) -> None: for thread in threads: thread.join() - @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=2, min=10)) def _start_llm(self, cuda_device, port): tensor_parallel_size = cuda_device.count(",") + 1 @@ -337,12 +335,7 @@ def calculate_stats(stats): } -def launch_training( - config_dir: str, - config_name: str, - accelerate_cfg_path: str, - use_deepspeed: bool = False -) -> None: +def launch_training(config_dir: str, config_name: str, accelerate_cfg_path: str, use_deepspeed: bool = False) -> None: """ Launch training process with proper GPU configuration and error handling. diff --git a/tapeagents/agent.py b/tapeagents/agent.py index f69dc634..e9405ecc 100644 --- a/tapeagents/agent.py +++ b/tapeagents/agent.py @@ -690,11 +690,13 @@ def _run_implementation(): else: raise ValueError("Agent can only generate steps or partial steps") n_iterations += 1 - updated_metadata = original_metadata.model_validate(dict( - parent_id=input_tape_id, - author=self.name, - n_added_steps=len(tape) - input_tape_length, - )) + updated_metadata = original_metadata.model_validate( + dict( + parent_id=input_tape_id, + author=self.name, + n_added_steps=len(tape) - input_tape_length, + ) + ) final_tape = tape.model_copy(update=dict(metadata=updated_metadata)) yield AgentEvent(final_tape=final_tape) diff --git a/tapeagents/batch.py b/tapeagents/batch.py index 91a2b8cc..9e0ce35f 100644 --- a/tapeagents/batch.py +++ b/tapeagents/batch.py @@ -34,7 +34,9 @@ def batch_main_loop( if not isinstance(environments, list): environments = [environments] * len(tapes) - def worker_func(input: tuple[TapeType, Environment], agent: Agent, max_loops: int, strict: bool) -> TapeType | Exception: + def worker_func( + input: tuple[TapeType, Environment], agent: Agent, max_loops: int, strict: bool + ) -> TapeType | Exception: start_tape, env = input try: result = main_loop(agent, start_tape, env, max_loops=max_loops).get_final_tape() diff --git a/tapeagents/core.py b/tapeagents/core.py index 91adebf9..e886ad16 100644 --- a/tapeagents/core.py +++ b/tapeagents/core.py @@ -438,5 +438,5 @@ class MakeObservation(Action, Generic[StepType]): def llm_dict(self) -> dict[str, Any]: """Dumps the step data as dictionary, excluding the metadata of the step itself and the metadata of the wrapped step""" obj = self.model_dump(exclude_none=True, exclude={"metadata"}) - del obj['new_observation']['metadata'] - return obj \ No newline at end of file + del obj["new_observation"]["metadata"] + return obj diff --git a/tapeagents/dialog_tape.py b/tapeagents/dialog_tape.py index 448fd046..973f409c 100644 --- a/tapeagents/dialog_tape.py +++ b/tapeagents/dialog_tape.py @@ -9,8 +9,6 @@ from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import BaseModel -from tapeagents.utils import image_base64_message - from .agent import Annotator from .core import ( Action, diff --git a/tapeagents/finetune/checkpoints.py b/tapeagents/finetune/checkpoints.py index 820198cd..14b2b9e3 100644 --- a/tapeagents/finetune/checkpoints.py +++ b/tapeagents/finetune/checkpoints.py @@ -334,13 +334,14 @@ def save_model_only( # convert to HF format on main process if accelerator.is_main_process: from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + logger.info("Converting DeepSpeed checkpoint to HF format") convert_zero_checkpoint_to_fp32_state_dict( checkpoint_dir=output_dir, output_dir=output_dir, tag=None, # will use 'global_step{step}' from DeepSpeed - safe_serialization=safe_serialization + safe_serialization=safe_serialization, ) # save model config diff --git a/tapeagents/finetune/finetune.py b/tapeagents/finetune/finetune.py index ac327028..18e0bf44 100644 --- a/tapeagents/finetune/finetune.py +++ b/tapeagents/finetune/finetune.py @@ -5,6 +5,7 @@ from collections import defaultdict from dataclasses import asdict from pathlib import Path + import numpy as np import torch from hydra import compose, initialize @@ -30,7 +31,7 @@ from .eval import evaluate_and_get_metrics from .logging_ import log_metrics, log_time, setup_logging from .optim import get_optimizer -from .rl import RLConfig, rl_step, make_rl_data_callback +from .rl import RLConfig, make_rl_data_callback, rl_step from .rl.utils import get_avg_rl_stats from .types import DataArgs, DataPartArgs, ModelClass, TrainingMetrics diff --git a/tapeagents/finetune/logging_.py b/tapeagents/finetune/logging_.py index 2ed67155..e376c9f0 100644 --- a/tapeagents/finetune/logging_.py +++ b/tapeagents/finetune/logging_.py @@ -8,12 +8,13 @@ import datasets import transformers import wandb -wandb.require("core") from omegaconf import DictConfig from wandb.sdk import wandb_run from .context import accelerator, logger +wandb.require("core") + def init_wandb( cfg: DictConfig, diff --git a/tapeagents/finetune/rl/__init__.py b/tapeagents/finetune/rl/__init__.py index 4b327c5d..f532bb04 100644 --- a/tapeagents/finetune/rl/__init__.py +++ b/tapeagents/finetune/rl/__init__.py @@ -58,6 +58,7 @@ class RLConfig(StepConfig): metadata={"help": "ReLU the weights before updating the model"}, ) + def make_rl_data_callback(args, current_dir, rl_config, model): if rl_config: populate_rl_data_ = partial( diff --git a/tapeagents/nodes.py b/tapeagents/nodes.py index 6f1d844c..157b3496 100644 --- a/tapeagents/nodes.py +++ b/tapeagents/nodes.py @@ -160,9 +160,13 @@ def tape_to_messages(self, tape: Tape, steps_description: str) -> list[dict]: Messages from tape are added with roles based on step type. If guidance exists, it's added as the final user message. """ - messages: list[dict] = [ - {"role": "system", "content": self.system_prompt}, - ] if self.system_prompt else [] + messages: list[dict] = ( + [ + {"role": "system", "content": self.system_prompt}, + ] + if self.system_prompt + else [] + ) if steps_description: messages.append({"role": "user", "content": steps_description}) for step in tape: diff --git a/tapeagents/observe.py b/tapeagents/observe.py index c7a8a0ab..d62ec1e2 100644 --- a/tapeagents/observe.py +++ b/tapeagents/observe.py @@ -10,6 +10,7 @@ import threading import time from typing import Callable, Optional, Type + from pydantic import BaseModel from .config import sqlite_db_path diff --git a/tapeagents/orchestrator.py b/tapeagents/orchestrator.py index 871f57be..855ef88e 100644 --- a/tapeagents/orchestrator.py +++ b/tapeagents/orchestrator.py @@ -268,7 +268,13 @@ def replay_tapes( for i, tape in enumerate(tapes): logger.debug(f"Tape {i}") try: - matched = replay_tape(agent, tape, env, start_tape=start_tapes[i] if start_tapes else None, reuse_observations=reuse_observations) + matched = replay_tape( + agent, + tape, + env, + start_tape=start_tapes[i] if start_tapes else None, + reuse_observations=reuse_observations, + ) if not matched: raise FatalError("Tape mismatch") ok += 1 diff --git a/tapeagents/parallel_processing.py b/tapeagents/parallel_processing.py index 4530e963..9ffbbb56 100644 --- a/tapeagents/parallel_processing.py +++ b/tapeagents/parallel_processing.py @@ -190,18 +190,19 @@ def producer(): if producer_thread.is_alive(): raise RuntimeError("Producer thread is still alive after timeout") + def eager_thread_pool_processor( stream: Iterable[InputType], worker_func: Callable[[InputType], OutputType], n_workers: int, initializer: None | Callable[..., None] = None, initargs: tuple[Any, ...] = (), - ordered: bool = False + ordered: bool = False, ) -> Generator[OutputType | Exception, None, None]: """ Processes a stream of items in a thread pool with eager processing. - Unlike lazy processing, this processor submits all tasks to the thread pool + Unlike lazy processing, this processor submits all tasks to the thread pool upfront and returns results as they complete. Args: @@ -220,18 +221,14 @@ def eager_thread_pool_processor( yield from (worker_func(item) for item in stream) return - with ThreadPoolExecutor( - max_workers=n_workers, - initializer=initializer, - initargs=initargs - ) as executor: + with ThreadPoolExecutor(max_workers=n_workers, initializer=initializer, initargs=initargs) as executor: # Submit all tasks upfront if ordered: # Preserve order of inputs futures = [] for item in stream: futures.append(executor.submit(worker_func, item)) - + # Yield results in original order for future in futures: try: @@ -241,13 +238,14 @@ def eager_thread_pool_processor( else: # Yield results as they complete (out of order) futures = [executor.submit(worker_func, item) for item in stream] - + for future in as_completed(futures): try: yield future.result() except Exception as e: yield e + def choose_processor(n_workers: int): return ( partial(eager_thread_pool_processor, n_workers=n_workers) diff --git a/tapeagents/steps.py b/tapeagents/steps.py index c7952e33..795e04a8 100644 --- a/tapeagents/steps.py +++ b/tapeagents/steps.py @@ -126,4 +126,4 @@ class UnknownStep(Step): class Annotation(Action): kind: Literal["annotation"] = "annotation" step: int - text: str \ No newline at end of file + text: str diff --git a/tests/add_sqlite_columns.py b/tests/add_sqlite_columns.py index 9b5852fb..0f8fd587 100644 --- a/tests/add_sqlite_columns.py +++ b/tests/add_sqlite_columns.py @@ -4,6 +4,7 @@ _llm_calls = "LLMCalls" + def add_columns_to_db(db_path): conn = sqlite3.connect(db_path) c = conn.cursor() @@ -16,6 +17,7 @@ def add_columns_to_db(db_path): conn.commit() conn.close() + def main(): dirname = sys.argv[1] for root, dirs, files in os.walk(dirname): @@ -25,6 +27,7 @@ def main(): print(db_path) add_columns_to_db(db_path) + if __name__ == "__main__": assert len(sys.argv) == 2, "Usage: python -m tests.add_sqlite_columns " - main() \ No newline at end of file + main() diff --git a/tests/test_llm_function.py b/tests/test_llm_function.py index ff4d6d32..802c6c3b 100644 --- a/tests/test_llm_function.py +++ b/tests/test_llm_function.py @@ -3,13 +3,12 @@ sys.path.append(str(Path(__file__).parent.parent.resolve())) -from examples.optimize.load_demos import load_agentic_rag_demos, load_rag_demos from examples.optimize.func_templates import make_answer_template, make_query_template +from examples.optimize.load_demos import load_agentic_rag_demos, load_rag_demos from tapeagents.dialog_tape import ToolResult, UserStep -from tapeagents.llm_function import Input, LLMFunctionTemplate, AssistantOutput, RationaleOutput +from tapeagents.llm_function import AssistantOutput, Input, LLMFunctionTemplate, RationaleOutput from tapeagents.utils import diff_strings - TEST_INPUT_STEP1 = UserStep( content="What is the nationality of the chef and restaurateur featured in Restaurant: Impossible?" )