From 273e04c947e91ccf80bf276f900145bde3e7c4fb Mon Sep 17 00:00:00 2001 From: Yury Hayeu Date: Thu, 7 Mar 2024 12:50:54 +0100 Subject: [PATCH] Fix on store training --- diploma_thesis/agents/base/marl_agent.py | 2 +- diploma_thesis/agents/base/rl_agent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diploma_thesis/agents/base/marl_agent.py b/diploma_thesis/agents/base/marl_agent.py index 4d01640..f5ac05f 100644 --- a/diploma_thesis/agents/base/marl_agent.py +++ b/diploma_thesis/agents/base/marl_agent.py @@ -66,7 +66,7 @@ def train_step(self): @filter(lambda self, *args: self.phase == TrainingPhase()) def store(self, key: Key, sample: TrainingSample): - self.trainer[key].store(sample) + self.trainer[key].store(sample, self.model.policy) def loss_record(self): if self.keys is None: diff --git a/diploma_thesis/agents/base/rl_agent.py b/diploma_thesis/agents/base/rl_agent.py index 51f9335..734bac5 100644 --- a/diploma_thesis/agents/base/rl_agent.py +++ b/diploma_thesis/agents/base/rl_agent.py @@ -36,7 +36,7 @@ def train_step(self): @filter(lambda self, *args: self.phase != EvaluationPhase()) def store(self, key: Key, sample: TrainingSample): - self.trainer.store(sample) + self.trainer.store(sample, self.model.policy) def loss_record(self): return self.trainer.loss_record()