Skip to content

Commit

Permalink
Merge branch 'main' into docs_releases
Browse files Browse the repository at this point in the history
  • Loading branch information
ollmer committed Dec 2, 2024
2 parents 4168875 + 3491e18 commit 4aa74b6
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 25 deletions.
16 changes: 11 additions & 5 deletions examples/gaia_agent/scripts/tape_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import sys
from collections import defaultdict

from pydantic import TypeAdapter

from tapeagents.core import Action
from tapeagents.io import load_tapes
from tapeagents.io import load_legacy_tapes
from tapeagents.observe import retrieve_all_llm_calls
from tapeagents.renderers.camera_ready_renderer import CameraReadyRenderer
from tapeagents.tape_browser import TapeBrowser

from ..eval import calculate_accuracy, get_exp_config_dict, tape_correct
from ..steps import GaiaStep
from ..tape import GaiaTape

logging.basicConfig(level=logging.INFO)
Expand All @@ -25,7 +28,7 @@ def load_tapes(self, name: str) -> list:
_, fname, postfix = name.split("/", maxsplit=2)
tapes_path = os.path.join(self.tapes_folder, fname, "tapes")
try:
all_tapes: list[GaiaTape] = load_tapes(GaiaTape, tapes_path, file_extension=".json") # type: ignore
all_tapes: list[GaiaTape] = load_legacy_tapes(GaiaTape, tapes_path, step_class=TypeAdapter(GaiaStep)) # type: ignore
except Exception as e:
logger.error(f"Failed to load tapes from {tapes_path}: {e}")
return []
Expand Down Expand Up @@ -113,7 +116,7 @@ def get_tape_label(self, tape: GaiaTape) -> str:

for step in tape:
prompt_id = step.metadata.prompt_id
if prompt_id:
if prompt_id and prompt_id in self.llm_calls:
llm_calls_num += 1
tokens_num += (
self.llm_calls[prompt_id].prompt_length_tokens + self.llm_calls[prompt_id].output_length_tokens
Expand Down Expand Up @@ -146,8 +149,11 @@ def get_tape_files(self) -> list[str]:
for postfix in ["1", "2", "3", "all"]:
for r in raw_exps:
exp_dir = os.path.join(self.tapes_folder, r)
cfg = get_exp_config_dict(exp_dir)
set_name = cfg["split"]
try:
cfg = get_exp_config_dict(exp_dir)
set_name = cfg["split"]
except Exception:
set_name = "-"
exps.append(f"{set_name}/{r}/{postfix}")
return sorted(exps)

Expand Down
74 changes: 54 additions & 20 deletions tapeagents/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import yaml
from pydantic import TypeAdapter

from tapeagents.dialog_tape import AssistantStep

from .core import Tape

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,6 +117,30 @@ def save_json_tape(tape: Tape, tapes_dir: str, name: str = ""):
f.write(tape.model_dump_json(indent=4))


def load_tape_dicts(path: Path | str, file_extension: str = ".yaml") -> list[dict]:
if not os.path.exists(path):
raise FileNotFoundError(f"File not found: {path}")
if file_extension not in (".yaml", ".json"):
raise ValueError(f"Unsupported file extension: {file_extension}")
if os.path.isdir(path):
logger.info(f"Loading tapes from dir {path}")
paths = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(file_extension)])
else:
paths = [path]
file_extension = os.path.splitext(path)[-1]
tapes = []
for path in paths:
with open(path) as f:
if file_extension == ".yaml":
data = list(yaml.safe_load_all(f))
else:
data = json.load(f)
if not isinstance(data, list):
data = [data]
tapes.extend(data)
return tapes


def load_tapes(tape_class: Type | TypeAdapter, path: Path | str, file_extension: str = ".yaml") -> list[Tape]:
"""Load tapes from dir with YAML or JSON files.
Expand All @@ -140,26 +166,34 @@ def load_tapes(tape_class: Type | TypeAdapter, path: Path | str, file_extension:
tapes = load_tapes(tape_adapter, "configs/tapes", ".json")
```
"""
if not os.path.exists(path):
raise FileNotFoundError(f"File not found: {path}")
if file_extension not in (".yaml", ".json"):
raise ValueError(f"Unsupported file extension: {file_extension}")
if os.path.isdir(path):
logger.info(f"Loading tapes from dir {path}")
paths = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(file_extension)])
else:
paths = [path]
file_extension = os.path.splitext(path)[-1]
tapes = []
loader = tape_class.model_validate if isinstance(tape_class, Type) else tape_class.validate_python
data = load_tape_dicts(path, file_extension)
for tape_dict in data:
tape = loader(tape_dict)
tapes.append(tape)
return tapes


def load_legacy_tapes(tape_class: Type | TypeAdapter, path: Path | str, step_class: Type | TypeAdapter) -> list[Tape]:
tapes = []
for path in paths:
with open(path) as f:
if file_extension == ".yaml":
data = list(yaml.safe_load_all(f))
else:
data = json.load(f)
if not isinstance(data, list):
data = [data]
for tape in data:
tapes.append(loader(tape))
loader = tape_class.model_validate if isinstance(tape_class, Type) else tape_class.validate_python
data = load_tape_dicts(path, ".json")
for tape_dict in data:
try:
tape = loader(tape_dict)
except Exception:
step_dicts = tape_dict["steps"]
tape_dict["steps"] = []
tape = loader(tape_dict)
step_loader = step_class.model_validate if isinstance(step_class, Type) else step_class.validate_python
steps = []
for step_dict in step_dicts:
try:
steps.append(step_loader(step_dict))
except Exception as e:
logger.warning(f"Failed to load step: {e}")
steps.append(AssistantStep(content=json.dumps(step_dict, indent=2, ensure_ascii=False)))
tape.steps = steps
tapes.append(tape)
return tapes

0 comments on commit 4aa74b6

Please sign in to comment.