
Commit

return token ids
AlexPiche committed Jan 6, 2025
1 parent 30c9640 commit 647a491
Showing 3 changed files with 5 additions and 10 deletions.
1 change: 1 addition & 0 deletions examples/rl_gsm8k/utils.py
@@ -155,6 +155,7 @@ def _start_llm(self, cuda_device, port):
             f"--model {self.model_name_or_path} "
             f"--tensor-parallel-size {tensor_parallel_size} "
             f"--port {port} "
+            "--return-tokens-as-token-ids "
             "--disable-frontend-multiprocessing "
             "--dtype bfloat16 "
             f"{kwargs_str}"
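With --return-tokens-as-token-ids, vLLM's OpenAI-compatible server reports each logprob token as a "token_id:<id>" string rather than the decoded text (see the comment added in tapeagents/llms.py below), so exact token ids can be read off the response without re-encoding. A minimal sketch of the two payload shapes, with purely illustrative ids, token texts, and logprob values:

# /v1/completions logprobs without the flag (illustrative values): ids must be
# recovered by re-encoding the token text.
logprobs_without_flag = {
    "tokens": [" the", " cat"],
    "token_logprobs": [-0.03, -1.42],
}

# With --return-tokens-as-token-ids (illustrative values): each token string
# embeds the exact id, e.g. "token_id:1271".
logprobs_with_flag = {
    "tokens": ["token_id:1271", "token_id:1505"],
    "token_logprobs": [-0.03, -1.42],
}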
1 change: 0 additions & 1 deletion tapeagents/core.py
@@ -355,7 +355,6 @@ def __bool__(self) -> bool:
 
 class TokenLogprob(BaseModel):
     logprob: float
-    token: str
     token_id: int
     generated: int
 
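After this change TokenLogprob stores only the token id; the decoded text is no longer kept on the model. A minimal usage sketch, assuming the pydantic BaseModel shown above and a tokenizer only when the text is actually needed (values are illustrative):

from pydantic import BaseModel


class TokenLogprob(BaseModel):
    logprob: float
    token_id: int
    generated: int


# An entry for a generated token (generated=1) with illustrative values.
lp = TokenLogprob(token_id=1271, logprob=-0.03, generated=1)

# If the text is ever needed it can be derived from the id on demand, e.g.
# tokenizer.decode([lp.token_id]) with the model's tokenizer.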
13 changes: 4 additions & 9 deletions tapeagents/llms.py
@@ -544,23 +544,16 @@ def make_llm_call_logprobs(
             logprobs.append(
                 TokenLogprob(
                     token_id=id,
-                    token=self.tokenizer.decode([id]),
                     logprob=0.0,
                     generated=0,
                 )
             )
         for logprob in completion_logprobs:
             if logprob:
                 try:
-                    token_id = self.tokenizer.encode(logprob["token"], add_special_tokens=False)
-                    if not len(token_id):
-                        # TODO: how should we handle empty tokens?
-                        logger.debug(f"Empty token: {logprob}")
-                        continue
                     logprobs.append(
                         TokenLogprob(
-                            token_id=token_id[0],
-                            token=logprob["token"],
+                            token_id=logprob["token_id"],
                             logprob=logprob["logprob"],
                             generated=1,
                         )
@@ -706,8 +699,10 @@ def batch_generate(self, prompts: list[Prompt]) -> list[LLMCall]:
             # /v1/completions returns logprobs in a format different to /v1/chat/completions
             # Before calling self.process_logprobs, we need to convert the logprobs to a
             # list of dicts format similar to /v1/chat/completions
+
+            #'tokens': ['token_id:1271', 'token_id:1505', '
             chat_completion_logprobs = [
-                {"token": completion_logprobs["tokens"][j], "logprob": completion_logprobs["token_logprobs"][j]}
+                {"token_id": int(completion_logprobs["tokens"][j].split(":")[-1]), "logprob": completion_logprobs["token_logprobs"][j]}
                 for j in range(len(completion_logprobs["tokens"]))
             ]
             logprobs = self.make_llm_call_logprobs(prompt_token_ids[i], chat_completion_logprobs)
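A standalone sketch of the conversion done above: the "token_id:<id>" strings returned by /v1/completions are parsed into integer ids and reshaped into the list-of-dicts format that make_llm_call_logprobs consumes (the helper name and payload values below are illustrative):

def to_chat_like_logprobs(completion_logprobs: dict) -> list[dict]:
    # Parse the id out of each "token_id:<id>" string and pair it with its logprob.
    return [
        {
            "token_id": int(token.split(":")[-1]),
            "logprob": completion_logprobs["token_logprobs"][j],
        }
        for j, token in enumerate(completion_logprobs["tokens"])
    ]


payload = {"tokens": ["token_id:1271", "token_id:1505"], "token_logprobs": [-0.03, -1.42]}
print(to_chat_like_logprobs(payload))
# [{'token_id': 1271, 'logprob': -0.03}, {'token_id': 1505, 'logprob': -1.42}]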
