Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Jan 23, 2025
1 parent 8e6eed0 commit 025d4fa
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/intro_clean.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@
"metadata": {},
"outputs": [],
"source": [
"from tapeagents.tools.search import web_search_tool\n",
"from tapeagents.tools.web_search import web_search_tool\n",
"from tapeagents.tools.simple_browser import SimpleTextBrowser\n",
"\n",
"browser = SimpleTextBrowser()\n",
Expand Down
2 changes: 1 addition & 1 deletion intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3028,7 +3028,7 @@
},
"outputs": [],
"source": [
"from tapeagents.tools.search import web_search_tool\n",
"from tapeagents.tools.web_search import web_search_tool\n",
"from tapeagents.tools.simple_browser import SimpleTextBrowser\n",
"\n",
"browser = SimpleTextBrowser()\n",
Expand Down
2 changes: 1 addition & 1 deletion tapeagents/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _implementation():
if _REPLAY_SQLITE:
closest, score = closest_prompt(key, list(self._cache.keys()))
logger.error(
f"LLM cache miss, closest in cache has score {score:.3f}\nDIFF:\n{diff_strings(key, closest)}"
f"LLM cache miss, closest in cache has score {score:.3f}\nNEW:\n{key}\nCLOSEST OLD:\n{closest}\nDIFF:\n{diff_strings(key, closest)}"
)
raise ValueError(f"LLM cache miss not allowed. Prompt key: {key}")
toks = self.count_tokens(prompt.messages)
Expand Down
7 changes: 2 additions & 5 deletions tapeagents/renderers/camera_ready_renderer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ast
import json
import os
import re

import yaml

Expand All @@ -19,7 +18,7 @@
from tapeagents.observe import LLMCall
from tapeagents.renderers.basic import BasicRenderer
from tapeagents.tools.code_executor import PythonCodeAction
from tapeagents.tools.container_executor import CodeBlock
from tapeagents.tools.container_executor import ANSI_ESCAPE_REGEX, CodeBlock
from tapeagents.view import Broadcast, Call, Respond

YELLOW = "#ffffba"
Expand All @@ -31,8 +30,6 @@
GREEN = "#6edb8f"
BLUE = "#bae1ff"

ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")


class CameraReadyRenderer(BasicRenderer):
def __init__(self, show_metadata=False, render_agent_node=True, show_content=True, **kwargs):
Expand Down Expand Up @@ -157,7 +154,7 @@ def format_code_block(block: CodeBlock) -> str:
elif isinstance(step, CodeExecutionResult):
text = f"exit_code:{step.result.exit_code}\n" if step.result.exit_code else ""
text += f"{maybe_fold(step.result.output, 2000)}"
text = ansi_escape.sub("", text)
text = ANSI_ESCAPE_REGEX.sub("", text)
if step.result.exit_code == 0 and step.result.output_files:
for file in step.result.output_files:
text += render_image(file)
Expand Down
9 changes: 3 additions & 6 deletions tapeagents/tools/container_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@
from pydantic import BaseModel, Field
from typing_extensions import Self

from tapeagents.utils import Lock

logger = logging.getLogger(__name__)

ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")

ANSI_ESCAPE_REGEX = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
DEFAULT_CONTAINER = "tapeagents-code-exec"


Expand All @@ -50,7 +47,7 @@ def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -


__all__ = ("ContainerExecutor",)
lock = Lock("container_executor")

DEFAULT_EXECUTION_POLICY = {
"bash": True,
"shell": True,
Expand Down Expand Up @@ -326,7 +323,7 @@ def execute_code_in_container(
raise e
assert isinstance(output, bytes)
output = output.decode("utf-8")
output = ansi_escape.sub("", output)
output = ANSI_ESCAPE_REGEX.sub("", output)
output = output.replace(filename, f"code.{code_block.language.lower()}")
if exit_code == 124:
output += "\n" + "Timeout"
Expand Down
12 changes: 10 additions & 2 deletions tapeagents/tools/stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
from tapeagents.tools.tool_cache import cached_tool


@cached_tool
def get_stock_ticker(company_name: str) -> str:
"""Get company stock ticker from its name."""
return _get_stock_ticker(company_name)


@cached_tool
def _get_stock_ticker(company_name: str) -> str:
yfinance = "https://query2.finance.yahoo.com/v1/finance/search"
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"
params = {"q": company_name, "quotes_count": 1, "country": "United States"}
Expand All @@ -19,7 +23,6 @@ def get_stock_ticker(company_name: str) -> str:
return company_code


@cached_tool
def get_stock_data(symbol: str, start_date: str, end_date: str) -> list[tuple]:
"""Get stock proces for a given symbol and date range.
Expand All @@ -31,6 +34,11 @@ def get_stock_data(symbol: str, start_date: str, end_date: str) -> list[tuple]:
Returns:
(list[tuple]): List of tuples, each tuple contains a 'YYYY-MM-DD' date and the stock price.
"""
return _get_stock_data(symbol, start_date, end_date)


@cached_tool
def _get_stock_data(symbol: str, start_date: str, end_date: str):
symbol = symbol.upper()
# parse timestamps using datetime
start_timestamp = int(datetime.datetime.strptime(start_date, "%Y-%m-%d").timestamp())
Expand Down
2 changes: 0 additions & 2 deletions tapeagents/tools/tool_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def wrapper(*args, **kwargs):
add_to_cache(fn_name, args, kwargs, result)
return result

wrapper.__name__ = tool_fn.__name__
wrapper.__doc__ = tool_fn.__doc__
return wrapper


Expand Down
6 changes: 5 additions & 1 deletion tapeagents/tools/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
logger = logging.getLogger(__name__)


@cached_tool
def web_search_tool(query: str, max_results: int = 5, retry_pause: int = 5, attempts: int = 3) -> list[dict]:
"""
Search the web for a given query, return a list of search result dictionaries.
"""
return _web_search(query, max_results=max_results, retry_pause=retry_pause, attempts=attempts)


@cached_tool
def _web_search(query: str, max_results: int = 5, retry_pause: int = 5, attempts: int = 3) -> list[dict]:
try:
results = web_search(query, max_results=max_results, retry_pause=retry_pause, attempts=attempts)
except Exception as e:
Expand Down

0 comments on commit 025d4fa

Please sign in to comment.