Skip to content

Commit

Permalink
Fix unwrapping of observation and action spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Dec 4, 2024
1 parent b39f316 commit 57d8fee
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
17 changes: 9 additions & 8 deletions morl_baselines/common/morl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,19 @@ def extract_env_info(self, env: Optional[gym.Env]) -> None:
self.env = env
if isinstance(self.env.observation_space, spaces.Discrete):
self.observation_shape = (1,)
self.observation_dim = self.env.unwrapped.observation_space.n
self.observation_dim = self.env.observation_space.n
else:
self.observation_shape = self.env.unwrapped.observation_space.shape
self.observation_dim = self.env.unwrapped.observation_space.shape[0]
self.observation_shape = self.env.observation_space.shape
self.observation_dim = self.env.observation_space.shape[0]

self.action_space = env.unwrapped.action_space
if isinstance(self.env.unwrapped.action_space, (spaces.Discrete, spaces.MultiBinary)):
self.action_space = env.action_space
if isinstance(self.env.action_space, (spaces.Discrete, spaces.MultiBinary)):
self.action_shape = (1,)
self.action_dim = self.env.unwrapped.action_space.n
self.action_dim = self.env.action_space.n
else:
self.action_shape = self.env.unwrapped.action_space.shape
self.action_dim = self.env.unwrapped.action_space.shape[0]
self.action_shape = self.env.action_space.shape
self.action_dim = self.env.action_space.shape[0]

self.reward_dim = self.env.unwrapped.reward_space.shape[0]

@abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def test_gpi_pd():


def test_gpi_pd_continuous_action():
env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
eval_env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)
eval_env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)

agent = GPIPDContinuousAction(
env,
Expand Down Expand Up @@ -278,8 +278,8 @@ def test_pcn():


def test_capql():
env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
eval_env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)
eval_env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)

agent = CAPQL(
env,
Expand Down

0 comments on commit 57d8fee

Please sign in to comment.