diff --git a/diploma_thesis/agents/base/marl_agent.py b/diploma_thesis/agents/base/marl_agent.py index 4d01640..f5ac05f 100644 --- a/diploma_thesis/agents/base/marl_agent.py +++ b/diploma_thesis/agents/base/marl_agent.py @@ -66,7 +66,7 @@ def train_step(self): @filter(lambda self, *args: self.phase == TrainingPhase()) def store(self, key: Key, sample: TrainingSample): - self.trainer[key].store(sample) + self.trainer[key].store(sample, self.model.policy) def loss_record(self): if self.keys is None: diff --git a/diploma_thesis/agents/base/rl_agent.py b/diploma_thesis/agents/base/rl_agent.py index 51f9335..734bac5 100644 --- a/diploma_thesis/agents/base/rl_agent.py +++ b/diploma_thesis/agents/base/rl_agent.py @@ -36,7 +36,7 @@ def train_step(self): @filter(lambda self, *args: self.phase != EvaluationPhase()) def store(self, key: Key, sample: TrainingSample): - self.trainer.store(sample) + self.trainer.store(sample, self.model.policy) def loss_record(self): return self.trainer.loss_record()