Skip to content

Commit

Permalink
apply lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jpt-sn committed Jan 10, 2025
1 parent f0c8448 commit cb77e17
Show file tree
Hide file tree
Showing 35 changed files with 400 additions and 378 deletions.
2 changes: 1 addition & 1 deletion examples/continue_tapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tapeagents.batch import generate_tapes
from tapeagents.dialog_tape import AssistantStep, DialogTape, SystemStep, UserStep
from tapeagents.environment import EmptyEnvironment
from tapeagents.llms import TrainableLLM, LLM
from tapeagents.llms import LLM, TrainableLLM

from .llama_agent import LLAMAChatBot
from .llama_user import LLAMAUserModel
Expand Down
2 changes: 1 addition & 1 deletion examples/data_science/tape_browser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
import sys
from pathlib import Path

from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
from tapeagents.tape_browser import TapeBrowser
Expand Down
2 changes: 1 addition & 1 deletion examples/delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tapeagents.agent import Agent, AgentEvent
from tapeagents.core import Action, Prompt, Tape, Thought
from tapeagents.llms import TrainableLLM, LLM, LLMStream
from tapeagents.llms import LLM, LLMStream, TrainableLLM

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

Expand Down
1 change: 0 additions & 1 deletion examples/form_filler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

from pydantic import Field


FunctionName: TypeAlias = Annotated[str, Field(description="The name of a function.")]
ParameterName: TypeAlias = Annotated[str, Field(description="The name of a function parameter.")]
10 changes: 5 additions & 5 deletions examples/form_filler/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@

from typing import Any, Type

from . import steps


def render_chat_template(messages: list[dict[str, Any]], context: dict[str, Any]) -> list[dict[str, Any]]:
formatted_chat = []
for message in messages:
formatted_content = message['content'].format(**context)
formatted_content = message["content"].format(**context)
formatted_chat.append(
dict(
role=message['role'],
role=message["role"],
content=formatted_content,
)
)
Expand All @@ -31,15 +31,15 @@ def sanitize_json_completion(completion: str) -> str:
clean_lines = []
for line in lines:
line = line.replace("\\", "") # remove all backslashes
line = ' '.join(line.split()) # remove all extra spaces
line = " ".join(line.split()) # remove all extra spaces
if line.startswith("```"):
tiks_counter += 1
if tiks_counter == 1:
clean_lines = []
elif tiks_counter == 2:
break
continue
elif line.startswith("[") or line.startswith("{"): # detected start of the json section
elif line.startswith("[") or line.startswith("{"): # detected start of the json section
if not opened:
opened = True
clean_lines = []
Expand Down
10 changes: 5 additions & 5 deletions examples/gaia_agent/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ def solve_task(
retries: int = 3,
max_loops: int = 50,
) -> Generator[GaiaTape, None, None]:
"""Solve GAIA task.
"""Solve GAIA task.
This function is a generator that yields intermediate tapes during the solving process.
The last tape will contain the agent's response.
"""
start_steps = env.task_to_observations(task)
solved = None
Expand All @@ -119,8 +119,8 @@ def solve_task(
if partial_tape := (event.agent_tape or event.env_tape):
tape = partial_tape
tape.metadata = GaiaMetadata.model_validate(
tape.metadata.model_dump() | {"task": task, "level": level}
)
tape.metadata.model_dump() | {"task": task, "level": level}
)
yield tape
if n_search_repetitions(tape) >= 3:
break
Expand Down
4 changes: 1 addition & 3 deletions examples/gaia_agent/scripts/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def main(cfg: DictConfig) -> None:
env = GaiaEnvironment(vision_lm=llm, code_sandbox=code_sandbox)
agent = GaiaAgent.create(llm, **cfg.agent)
tape = GaiaTape(steps=env.task_to_observations(task))
tape.metadata = GaiaMetadata.model_validate(
tape.metadata.model_dump() | {"task": task, "level": cfg.level}
)
tape.metadata = GaiaMetadata.model_validate(tape.metadata.model_dump() | {"task": task, "level": cfg.level})
step_count = 0
for event in main_loop(agent, tape, env, max_loops=50):
if event.agent_event and event.agent_event.step:
Expand Down
4 changes: 2 additions & 2 deletions examples/gaia_agent/scripts/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def main(cfg: DictConfig) -> None:
env = GaiaEnvironment(vision_lm=llm, attachment_dir=attachment_dir)
agent = GaiaAgent.create(llm, **cfg.agent)
content = "How many calories in 2 teaspoons of hummus"
if cfg.studio.tape:
if cfg.studio.tape:
tape = load_tapes(GaiaTape, cfg.studio.tape, ".json")[0]
else:
# Uncomment the following line to test video question
# content = "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
tape = GaiaTape(steps=[GaiaQuestion(content=content)])
tape = GaiaTape(steps=[GaiaQuestion(content=content)])
Studio(agent, tape, CameraReadyRenderer(), env).launch(server_name="0.0.0.0", static_dir=attachment_dir)


Expand Down
2 changes: 1 addition & 1 deletion examples/gsm8k_tuning/math_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from tapeagents.agent import Agent
from tapeagents.core import (
Action,
LLMOutputParsingFailureAction,
FinalStep,
LLMOutputParsingFailureAction,
Observation,
SetNextNode,
Tape,
Expand Down
2 changes: 1 addition & 1 deletion examples/optimize/func_templates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tapeagents.dialog_tape import ToolResult
from tapeagents.llm_function import Input, LLMFunctionTemplate, AssistantOutput, RationaleOutput, ToolCallOutput
from tapeagents.llm_function import AssistantOutput, Input, LLMFunctionTemplate, RationaleOutput, ToolCallOutput


def render_contexts(contexts: list[str]) -> str:
Expand Down
1 change: 0 additions & 1 deletion examples/optimize/load_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
UserStep,
)


res_dir = pathlib.Path(__file__).parent.parent.resolve() / "res"


Expand Down
4 changes: 2 additions & 2 deletions examples/rl_gsm8k/browse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from pathlib import Path
import sys
from pathlib import Path

from examples.rl_gsm8k.cot_math_agent import MathTape
from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
from tapeagents.tape_browser import TapeBrowser
from examples.rl_gsm8k.cot_math_agent import MathTape

# comment this code out if loading the prompt and completions takes too long for you
tape_dir = Path(sys.argv[1])
Expand Down
104 changes: 60 additions & 44 deletions examples/rl_gsm8k/deepseek_math_eval/answer_extraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re

import regex


def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
Expand Down Expand Up @@ -129,7 +131,7 @@ def strip_string(string):
string = string.replace("inf", "\\infty")
string = string.replace("+\\inity", "\\infty")

# and
# and
# string = string.replace("and", "")
string = string.replace("\\mathbf", "")
string = string.replace("\\mathrm", "")
Expand All @@ -139,8 +141,8 @@ def strip_string(string):

# quote
string.replace("'", "")
string.replace("\"", "")
string.replace('"', "")

# i, j
if "j" in string and "i" not in string:
string = string.replace("j", "i")
Expand Down Expand Up @@ -174,57 +176,60 @@ def strip_string(string):

return string


def extract_boxed_answers(text):
    """Return the contents of every brace-balanced ``boxed{...}`` group in *text*.

    The text is split on the literal ``boxed{``; each following piece is
    scanned until the brace that closes the group is found. Pieces whose
    group is never closed contribute nothing.

    Args:
        text: String that may contain one or more ``boxed{...}`` groups.

    Returns:
        A list of the extracted group contents, in order of appearance.
    """
    answers = []
    for piece in text.split("boxed{")[1:]:
        depth = 0  # nesting level relative to the opening ``boxed{``
        for idx, char in enumerate(piece):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth < 0:
                    # This brace closes the boxed group itself. When it is
                    # directly followed by '%', the brace is kept in the
                    # extracted answer (behaviour preserved as-is).
                    end = idx + 1 if piece[idx + 1 : idx + 2] == "%" else idx
                    answers.append(piece[:end])
                    break
    return answers


def extract_program_output(pred_str):
    """Extract the text of the last ```output ... ``` block in *pred_str*.

    Args:
        pred_str: Model completion that may contain fenced ``output`` blocks.

    Returns:
        The whitespace-stripped text between the last "```output" marker and
        the next "```" (or the end of the string if the fence is not closed).
        An empty string when no "```output" marker is present.
    """
    if "```output" not in pred_str:
        return ""
    # Everything after the final "```output" marker ...
    tail = pred_str.rpartition("```output")[2]
    # ... up to the closing fence, if there is one.
    body = tail.partition("```")[0]
    return body.strip()


def extract_answer(pred_str, exhaust=False):
pred = []
if 'final answer is $' in pred_str and '$. I hope' in pred_str:
tmp = pred_str.split('final answer is $', 1)[1]
pred = [tmp.split('$. I hope', 1)[0].strip()]
elif 'boxed' in pred_str:
if "final answer is $" in pred_str and "$. I hope" in pred_str:
tmp = pred_str.split("final answer is $", 1)[1]
pred = [tmp.split("$. I hope", 1)[0].strip()]
elif "boxed" in pred_str:
pred = extract_boxed_answers(pred_str)
elif ('he answer is' in pred_str):
pred = [pred_str.split('he answer is')[-1].strip()]
elif "he answer is" in pred_str:
pred = [pred_str.split("he answer is")[-1].strip()]
else:
program_output = extract_program_output(pred_str)
if program_output != "":
# fall back to program
pred.append(program_output)
else: # use the last number
pattern = '-?\d*\.?\d+'
else: # use the last number
pattern = "-?\d*\.?\d+"
ans = re.findall(pattern, pred_str.replace(",", ""))
if(len(ans) >= 1):
if len(ans) >= 1:
ans = ans[-1]
else:
ans = ''
ans = ""
if ans:
pred.append(ans)

Expand All @@ -242,94 +247,105 @@ def extract_answer(pred_str, exhaust=False):
else:
return _pred[-1] if _pred else ""


def extract_math_answer(question, reasoning, task):
    """Collect all answers found in *reasoning*, splitting multi-part answers.

    Every candidate returned by ``extract_answer(reasoning, exhaust=True)`` is
    handled as follows:

    * if the question contains "separated by commas" and the candidate has no
      round or square brackets, it is split on commas;
    * otherwise, if it contains a LaTeX ``\\text{and}`` separator, it is split
      on that separator;
    * otherwise it is kept whole.

    Args:
        question: The problem statement.
        reasoning: The model's completion.
        task: Unused here; kept so all extractors share one signature.

    Returns:
        A list of whitespace-stripped answer strings.
    """
    and_separator = r"\\text\{\s*and\s*\}"
    answer = []
    for ans in extract_answer(reasoning, exhaust=True):
        if "separated by commas" in question and all(ch not in ans for ch in "()[]"):
            answer.extend(part.strip() for part in ans.split(","))
        elif regex.search(and_separator, ans):
            # Replace each separator with a sentinel, then split on it.
            parts = regex.sub(and_separator, "[SEP]", ans).split("[SEP]")
            answer.extend(part.strip() for part in parts)
        else:
            answer.append(ans.strip())
    return answer


def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Extract answers from a few-shot CoT completion.

    Anything from the first "Problem:" onwards is discarded before delegating
    to ``extract_math_answer``.

    Args:
        question: The problem statement.
        reasoning: The model's completion.
        task: Passed through unchanged to ``extract_math_answer``.

    Returns:
        Whatever ``extract_math_answer`` returns for the truncated reasoning.
    """
    if "Problem:" in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    return extract_math_answer(question, reasoning, task)


def extract_last_single_answer(question, reasoning, task):
    """Return the result of ``extract_answer(reasoning, exhaust=False)``.

    ``question`` and ``task`` are unused; they are kept so all extractors
    share one signature.
    """
    return extract_answer(reasoning, exhaust=False)


def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number in *reasoning* before the first "Q: " marker.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        The last substring matching ``-?\\d+\\.?\\d*`` (a trailing dot is part
        of the match when present), or "[invalid]" when there is no number.
    """
    if "Q: " in reasoning:
        # Drop everything from the first "Q: " onwards.
        reasoning = reasoning.split("Q: ", 1)[0]
    numbers = re.findall(r"-?\d+\.?\d*", reasoning)
    return numbers[-1] if numbers else "[invalid]"


def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Extract the cloze answer that follows the "答案是" marker.

    The reasoning is first cut at the first "问题 " marker. The answer is the
    first line after the first "答案是", stripped of surrounding whitespace
    and of leading/trailing ``$`` characters.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        A one-element list with the answer, or ``["placeholder"]`` when the
        "答案是" marker is absent.
    """
    if "问题 " in reasoning:
        reasoning = reasoning.split("问题 ", 1)[0]
    _, marker, tail = reasoning.partition("答案是")
    if not marker:
        return ["placeholder"]
    first_line = tail.strip().split("\n")[0].strip()
    return [first_line.strip("$")]


def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Extract the answer that follows the "答案是" marker.

    The reasoning is first cut at the first "问题 " marker. The answer is the
    first line after the first "答案是", stripped of surrounding whitespace.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        The answer string, or "placeholder" when the marker is absent.
    """
    if "问题 " in reasoning:
        reasoning = reasoning.split("问题 ", 1)[0]
    _, marker, tail = reasoning.partition("答案是")
    if not marker:
        return "placeholder"
    return tail.strip().split("\n")[0].strip()


def extract_sat_few_shot_answer(question, reasoning, task):
    """Extract the multiple-choice letter from "the final answer is (X)".

    The reasoning is cut at the first "Problem:" marker and lower-cased
    before matching, so the phrase is matched case-insensitively. Parentheses
    around the letter are optional.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        The upper-cased letter (A-D), or "placeholder" when nothing matches.
    """
    if "Problem:" in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    match = re.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
    if match is None:
        return "placeholder"
    return match.group("ans").upper()


def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Extract the answer from "final answer is <ans>. I hope it is correct.".

    The reasoning is cut at the first "Problem:" marker before matching.
    When the phrase is not found, the truncated reasoning is printed for
    debugging and "[invalid]" is returned.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        The captured answer text, or "[invalid]" when nothing matches.
    """
    if "Problem:" in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    # NOTE(review): the final "." in the pattern is unescaped, so it matches
    # any single character after "correct" - kept as-is to avoid changing
    # which completions are accepted.
    match = re.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
    if match is None:
        print(f"DEBUG >>>\n{reasoning}", flush=True)
        return "[invalid]"
    return match.group("ans")


def extract_mmlu_stem(question, reasoning, task):
    """Extract an MMLU-STEM answer letter via ``extract_sat_few_shot_answer``.

    The reasoning is cut at the first "Problem:" marker before delegating.

    Args:
        question: Passed through unchanged.
        reasoning: The model's completion.
        task: Passed through unchanged.

    Returns:
        Whatever ``extract_sat_few_shot_answer`` returns for the truncated
        reasoning.
    """
    if "Problem:" in reasoning:
        reasoning = reasoning.split("Problem:", 1)[0]
    return extract_sat_few_shot_answer(question, reasoning, task)


def extract_minif2f_isabelle(question, reasoning, task):
    """Return *reasoning* up to the first "Informal:" marker, stripped.

    Args:
        question: Unused; kept so all extractors share one signature.
        reasoning: The model's completion.
        task: Unused; kept so all extractors share one signature.

    Returns:
        The text before the first "Informal:" (or the whole text when the
        marker is absent), with surrounding whitespace removed.
    """
    if "Informal:" in reasoning:
        reasoning = reasoning.split("Informal:", 1)[0]
    return reasoning.strip()


def extract_cmath_few_shot_test(question, reasoning, task):
if '问题:' in reasoning:
if "问题:" in reasoning:
reasoning = reasoning.split("问题:", 1)[0]
if '答案是' in reasoning:
ans = reasoning.split('答案是', 1)[1].strip()
if "答案是" in reasoning:
ans = reasoning.split("答案是", 1)[1].strip()
ans = ans.split("\n")[0]
ans = ans.strip(":")
ans = ans.strip("。")
try:
ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1]
ans = [s for s in regex.findall(r"-?\d+\.?\d*", ans)][-1]
except:
print(f"DEBUG CMATH: {reasoning}", flush=True)
ans = "[invalid]"
Expand Down
Loading

0 comments on commit cb77e17

Please sign in to comment.