diff --git a/examples/rl_gsm8k/utils.py b/examples/rl_gsm8k/utils.py
index 8f43cb7e..7f8c950c 100644
--- a/examples/rl_gsm8k/utils.py
+++ b/examples/rl_gsm8k/utils.py
@@ -155,6 +155,7 @@ def _start_llm(self, cuda_device, port):
             f"--model {self.model_name_or_path} "
             f"--tensor-parallel-size {tensor_parallel_size} "
             f"--port {port} "
+            "--return-tokens-as-token-ids "
             "--disable-frontend-multiprocessing "
             "--dtype bfloat16 "
             f"{kwargs_str}"
diff --git a/tapeagents/core.py b/tapeagents/core.py
index 3ddb1fba..59e8d4f0 100644
--- a/tapeagents/core.py
+++ b/tapeagents/core.py
@@ -355,7 +355,6 @@ def __bool__(self) -> bool:
 
 class TokenLogprob(BaseModel):
     logprob: float
-    token: str
     token_id: int
     generated: int
 
diff --git a/tapeagents/llms.py b/tapeagents/llms.py
index 5a934b95..30401ff7 100644
--- a/tapeagents/llms.py
+++ b/tapeagents/llms.py
@@ -544,7 +544,6 @@ def make_llm_call_logprobs(
             logprobs.append(
                 TokenLogprob(
                     token_id=id,
-                    token=self.tokenizer.decode([id]),
                     logprob=0.0,
                     generated=0,
                 )
@@ -552,15 +551,9 @@
         for logprob in completion_logprobs:
             if logprob:
                 try:
-                    token_id = self.tokenizer.encode(logprob["token"], add_special_tokens=False)
-                    if not len(token_id):
-                        # TODO: how should we handle empty tokens?
-                        logger.debug(f"Empty token: {logprob}")
-                        continue
                     logprobs.append(
                         TokenLogprob(
-                            token_id=token_id[0],
-                            token=logprob["token"],
+                            token_id=logprob["token_id"],
                             logprob=logprob["logprob"],
                             generated=1,
                         )
@@ -706,8 +699,10 @@ def batch_generate(self, prompts: list[Prompt]) -> list[LLMCall]:
             # /v1/completions returns logprobs in a format different to /v1/chat/completions
             # Before calling self.process_logprobs, we need to convert the logprobs to a
             # list of dicts format similar to /v1/chat/completions
+
+            # With --return-tokens-as-token-ids, e.g. 'tokens': ['token_id:1271', 'token_id:1505', ...]
             chat_completion_logprobs = [
-                {"token": completion_logprobs["tokens"][j], "logprob": completion_logprobs["token_logprobs"][j]}
+                {"token_id": int(completion_logprobs["tokens"][j].split(":")[-1]), "logprob": completion_logprobs["token_logprobs"][j]}
                 for j in range(len(completion_logprobs["tokens"]))
             ]
             logprobs = self.make_llm_call_logprobs(prompt_token_ids[i], chat_completion_logprobs)
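
Note on the conversion above, with a minimal runnable sketch (not part of the patch): when vLLM is started with `--return-tokens-as-token-ids`, `/v1/completions` reports each sampled token as the string `token_id:<int>`, so the integer id can be recovered with a plain string split instead of the old lossy `tokenizer.encode` round trip, which could return zero or several ids for one token string. `TokenLogprob` is re-declared below for self-containment, and `parse_completion_logprobs` plus the sample payload are illustrative names, not part of the codebase.

```python
# Sketch of the parsing this PR relies on; assumes the vLLM server was started
# with --return-tokens-as-token-ids, as done in _start_llm above.
from pydantic import BaseModel


class TokenLogprob(BaseModel):  # re-declared from tapeagents/core.py for self-containment
    logprob: float
    token_id: int
    generated: int


def parse_completion_logprobs(completion_logprobs: dict) -> list[TokenLogprob]:
    """Convert /v1/completions logprobs into TokenLogprob records.

    Each entry of completion_logprobs["tokens"] is a string "token_id:<int>",
    so splitting on ":" recovers the id without touching the tokenizer.
    """
    return [
        TokenLogprob(
            token_id=int(token.split(":")[-1]),
            logprob=logprob,
            generated=1,  # these are sampled tokens, not prompt tokens
        )
        for token, logprob in zip(
            completion_logprobs["tokens"], completion_logprobs["token_logprobs"]
        )
    ]


if __name__ == "__main__":
    # Illustrative payload shaped like the response format quoted in the diff.
    sample = {"tokens": ["token_id:1271", "token_id:1505"], "token_logprobs": [-0.12, -1.87]}
    print(parse_completion_logprobs(sample))
```

Besides skipping a decode/encode pass per token, this removes the empty-token edge case the deleted `continue` branch had to guard against, since the server now returns the exact ids it sampled.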