From 0188e29201cc49915db8b8095865b44b052cf9ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Willem=20R=C3=B6pke?=
Date: Sat, 6 Jan 2024 11:11:26 +0100
Subject: [PATCH] Fix bug where tolist was called on a float

---
 morl_baselines/multi_policy/pcn/pcn.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/morl_baselines/multi_policy/pcn/pcn.py b/morl_baselines/multi_policy/pcn/pcn.py
index 38af0b82..a3fd687a 100644
--- a/morl_baselines/multi_policy/pcn/pcn.py
+++ b/morl_baselines/multi_policy/pcn/pcn.py
@@ -389,7 +389,7 @@ def train(
         num_er_episodes: int = 20,
         num_step_episodes: int = 10,
         num_model_updates: int = 50,
-        max_return: np.ndarray = 100.0,
+        max_return: np.ndarray = None,
         max_buffer_size: int = 100,
         num_points_pf: int = 100,
     ):
@@ -403,10 +403,11 @@ def train(
             num_er_episodes: number of episodes to fill experience replay buffer
             num_step_episodes: number of steps per episode
             num_model_updates: number of model updates per episode
-            max_return: maximum return for clipping desired return
+            max_return: maximum return for clipping desired return. When None, this will be set to 100 for all objectives.
             max_buffer_size: maximum buffer size
             num_points_pf: number of points to sample from pareto front for metrics calculation
         """
+        max_return = max_return if max_return is not None else np.full(self.reward_dim, 100.0, dtype=np.float32)
         if self.log:
             self.register_additional_config(
                 {
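
Note on the fix (an illustrative sketch, not part of the patch): the old default of 100.0 was a plain Python float, which has no .tolist() method, while the config registration touched by the second hunk presumably serializes max_return with .tolist(), as the subject line indicates. The sketch below reproduces that failure and shows the behaviour of the ndarray default that the patch builds with np.full; reward_dim = 2 is a hypothetical number of objectives chosen only for the example.

import numpy as np

# Hypothetical number of objectives, for illustration only.
reward_dim = 2

# Old default: a plain float cannot be converted with .tolist().
old_default = 100.0
try:
    old_default.tolist()
except AttributeError as exc:
    print(exc)  # 'float' object has no attribute 'tolist'

# New behaviour: None is resolved to one maximum return per objective,
# mirroring np.full(self.reward_dim, 100.0, dtype=np.float32) in the patch.
max_return = np.full(reward_dim, 100.0, dtype=np.float32)
print(max_return.tolist())  # [100.0, 100.0]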