Commit

run lint
AlexPiche committed Jan 22, 2025
1 parent 644cf78 commit 56a3b1d
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions tapeagents/llms.py
@@ -208,7 +208,9 @@ def get_token_costs(self) -> dict:
         """
         return {"input": 0, "output": 0}
 
-    def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False, count_tokens: bool = True) -> None | LLMCall:
+    def log_output(
+        self, prompt: Prompt, message: LLMOutput, cached: bool = False, count_tokens: bool = True
+    ) -> None | LLMCall:
         """
         Logs the output of an LLM (Language Model) call along with its metadata.
 
@@ -233,7 +235,7 @@ def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False, c
         else:
             # -1 is the default value of prompt and output length tokens when token counting is disabled
             prompt_length_tokens = -1
-            output_length_tokens = -1
+            output_length_tokens = -1
 
         llm_call = LLMCall(
             timestamp=datetime.datetime.now().isoformat(),
@@ -735,7 +737,9 @@ def batch_generate(self, prompts: list[Prompt]) -> list[LLMCall]:
                 llm_call.output_length_tokens = len(chat_completion_logprobs)
                 self._stats["prompt_length_tokens"].append(llm_call.prompt_length_tokens)
                 self._stats["output_length_tokens"].append(llm_call.output_length_tokens)
-                assert llm_call.output_length_tokens <= self.parameters["max_tokens"], f"output_length_tokens: {llm_call.output_length_tokens}, max_tokens: {self.parameters['max_tokens']}"
+                assert (
+                    llm_call.output_length_tokens <= self.parameters["max_tokens"]
+                ), f"output_length_tokens: {llm_call.output_length_tokens}, max_tokens: {self.parameters['max_tokens']}"
             else:
                 llm_call = self.log_output(prompts[i], output, count_tokens=True)
             llm_call.logprobs = logprobs
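Both llms.py hunks are formatter-style rewraps of lines that exceeded the lint line-length limit; the only guard they touch is the assert in batch_generate. Below is a minimal sketch of the invariant that assert enforces, outside of TapeAgents: logprob_entries and max_tokens stand in for chat_completion_logprobs and self.parameters["max_tokens"] from the diff, and the function itself is hypothetical.

from typing import Any


def checked_output_length(logprob_entries: list[Any], max_tokens: int) -> int:
    # One logprob entry is returned per generated token, so the entry
    # count is the output length; since generation is capped at
    # max_tokens, the count should never exceed it.
    output_length_tokens = len(logprob_entries)
    assert (
        output_length_tokens <= max_tokens
    ), f"output_length_tokens: {output_length_tokens}, max_tokens: {max_tokens}"
    return output_length_tokens

Wrapping the assert condition in parentheses, as the diff does, is how a formatter such as black keeps the condition and its message within the configured line limit.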
2 changes: 1 addition & 1 deletion tapeagents/nodes.py
@@ -90,7 +90,7 @@ def make_prompt(self, agent: Any, tape: Tape) -> Prompt:
         cleaned_tape = self.prepare_tape(tape)
         steps_description = self.get_steps_description(tape, agent)
         messages = self.tape_to_messages(cleaned_tape, steps_description)
-        #TODO: benchmark counting the token for every single call
+        # TODO: benchmark counting the token for every single call
         if self.trim_tape_when_too_long and agent.llm.count_tokens(messages) > (agent.llm.context_size - 500):
             cleaned_tape = self.trim_tape(cleaned_tape)
             messages = self.tape_to_messages(cleaned_tape, steps_description)
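The nodes.py change only normalizes comment spacing ("# TODO" instead of "#TODO"), but the line it touches sits on the trim-when-too-long path of make_prompt. Below is a minimal sketch of that pattern under stated assumptions: count_tokens here is a crude whitespace stand-in for the real tokenizer, and fit_to_context is a hypothetical helper; only the 500-token headroom and the overall flow mirror the diff.

from typing import Callable

HEADROOM = 500  # mirrors the (context_size - 500) margin in the diff


def count_tokens(messages: list[dict]) -> int:
    # Crude stand-in tokenizer: whitespace-split word count. The TODO in
    # the diff asks to benchmark the cost of a real count on every call.
    return sum(len(m["content"].split()) for m in messages)


def fit_to_context(
    messages: list[dict], context_size: int, trim_and_rebuild: Callable[[], list[dict]]
) -> list[dict]:
    # If the prompt would overflow the context window minus the headroom,
    # trim the tape once and rebuild the messages from the trimmed tape.
    if count_tokens(messages) > context_size - HEADROOM:
        messages = trim_and_rebuild()
    return messages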
