diff --git a/tapeagents/agent.py b/tapeagents/agent.py index 321982b2..04bb6f89 100644 --- a/tapeagents/agent.py +++ b/tapeagents/agent.py @@ -711,7 +711,6 @@ def run_batch(self: Agent[TapeType], tapes: list[TapeType]) -> list[Tape]: if not isinstance(self.llm, TrainableLLM): raise NotImplementedError("For run_agent_batch the LLM must be TrainableLLM") original_tapes = list(tapes) - parent_ids = [tape.metadata.id for tape in tapes] n_iterations = 0 active_indices = set(range(len(tapes))) while n_iterations < self.max_iterations: @@ -737,7 +736,7 @@ def run_batch(self: Agent[TapeType], tapes: list[TapeType]) -> list[Tape]: n_iterations += 1 for i in range(len(tapes)): updated_metadata = original_tapes[i].metadata.model_validate(dict( - parent_id=parent_ids[i], + parent_id=original_tapes[i].metadata.id, author=self.name, n_added_steps=len(tapes[i]) - len(original_tapes[i]) ))