From 9afcb3a94a5b2611735aac370cd9230920264aed Mon Sep 17 00:00:00 2001 From: Lucas Alegre Date: Thu, 25 Jan 2024 13:29:11 -0300 Subject: [PATCH] Remove dropout at evaluation time on GPI-LS (#90) --- morl_baselines/multi_policy/gpi_pd/gpi_pd.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/morl_baselines/multi_policy/gpi_pd/gpi_pd.py b/morl_baselines/multi_policy/gpi_pd/gpi_pd.py index 50153bf8..17b37147 100644 --- a/morl_baselines/multi_policy/gpi_pd/gpi_pd.py +++ b/morl_baselines/multi_policy/gpi_pd/gpi_pd.py @@ -588,10 +588,14 @@ def eval(self, obs: np.ndarray, w: np.ndarray) -> int: """Select an action for the given obs and weight vector.""" obs = th.as_tensor(obs).float().to(self.device) w = th.as_tensor(w).float().to(self.device) + for q_net in self.q_nets: + q_net.eval() if self.use_gpi: action = self.gpi_action(obs, w, include_w=False) else: action = self.max_action(obs, w) + for q_net in self.q_nets: + q_net.train() return action def _act(self, obs: th.Tensor, w: th.Tensor) -> int: