do not count tokens using the tokenizer
AlexPiche committed Jan 22, 2025
1 parent b5f5643 commit 666c3d5
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions tapeagents/llms.py
@@ -208,7 +208,7 @@ def get_token_costs(self) -> dict:
         """
         return {"input": 0, "output": 0}
 
-    def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False) -> None | LLMCall:
+    def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False, count_tokens: bool = True) -> None | LLMCall:
         """
         Logs the output of an LLM (Language Model) call along with its metadata.
@@ -219,14 +219,21 @@ def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False) -> None | LLMCall:
         """
 
         start_log_output = time.time()
-        prompt_length_tokens = self.count_tokens(prompt.messages)
-        if message.content:
-            output_length_tokens = (
-                self.count_tokens(prompt.messages + [{"role": "assistant", "content": message.content}])
-                - prompt_length_tokens
-            )
+        if count_tokens:
+            prompt_length_tokens = self.count_tokens(prompt.messages)
+            if message.content:
+                output_length_tokens = (
+                    self.count_tokens(prompt.messages + [{"role": "assistant", "content": message.content}])
+                    - prompt_length_tokens
+                )
+            else:
+                output_length_tokens = 0
+            self._stats["prompt_length_tokens"].append(prompt_length_tokens)
+            self._stats["output_length_tokens"].append(output_length_tokens)
         else:
-            output_length_tokens = 0
+            # -1 is the default value of prompt and output length tokens when token counting is disabled
+            prompt_length_tokens = -1
+            output_length_tokens = -1
 
         llm_call = LLMCall(
             timestamp=datetime.datetime.now().isoformat(),
@@ -237,8 +244,6 @@ def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False) -> None | LLMCall:
             cached=cached,
             llm_info=self.get_info(),
         )
-        self._stats["prompt_length_tokens"].append(prompt_length_tokens)
-        self._stats["output_length_tokens"].append(output_length_tokens)
         token_costs = self.get_token_costs()
         llm_call.cost = (
             token_costs["input"] * llm_call.prompt_length_tokens + token_costs["output"] * llm_call.output_length_tokens
@@ -722,7 +727,14 @@ def batch_generate(self, prompts: list[Prompt]) -> list[LLMCall]:
                 logger.exception(f"Failed to parse llm response: {r}")
                 raise e
             output = LLMOutput(content=content)
-            llm_call = self.log_output(prompts[i], output)
+            if logprobs:
+                llm_call = self.log_output(prompts[i], output, count_tokens=False)
+                llm_call.prompt_length_tokens = len(prompt_token_ids[i])
+                llm_call.output_length_tokens = len(logprobs)
+                self._stats["prompt_length_tokens"].append(llm_call.prompt_length_tokens)
+                self._stats["output_length_tokens"].append(llm_call.output_length_tokens)
+            else:
+                llm_call = self.log_output(prompts[i], output, count_tokens=True)
             llm_call.logprobs = logprobs
             result.append(llm_call)
         self._stats["time_postprocess_llm_response"].append(time.time() - start_postprocess_time)
