Add support for wrapped inner VectorEnvs
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Aug 3, 2021
1 parent 8b24de6 commit 6795186
Showing 2 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions gym/vector/async_vector_env.py
@@ -409,7 +409,7 @@ def step_fn(actions):
     # Do nothing if the env is a VectorEnv, since it will automatically
     # reset the envs that are done if needed in the 'step' method and return
     # the initial observation instead of the final observation.
-    if not isinstance(env, VectorEnv) and done:
+    if not isinstance(env.unwrapped, VectorEnv) and done:
         observation = env.reset()
     return observation, reward, done, info

@@ -454,7 +454,7 @@ def step_fn(actions):
     # Do nothing if the env is a VectorEnv, since it will automatically
     # reset the envs that are done if needed in the 'step' method and return
     # the initial observation instead of the final observation.
-    if not isinstance(env, VectorEnv) and done:
+    if not isinstance(env.unwrapped, VectorEnv) and done:
         observation = env.reset()
     return observation, reward, done, info

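The async_vector_env.py change is confined to the isinstance check: testing env.unwrapped instead of env means a VectorEnv that sits behind one or more gym.Wrapper layers is still recognized, so the worker does not reset it a second time on done. A minimal sketch of the distinction, assuming gym's standard Wrapper.unwrapped behaviour (the CartPole-v1 factories and the bare gym.Wrapper are illustrative placeholders, not part of this commit):

    import gym
    from gym.vector import SyncVectorEnv, VectorEnv

    # A small vectorized env hidden behind a pass-through wrapper.
    inner = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    wrapped = gym.Wrapper(inner)

    print(isinstance(wrapped, VectorEnv))            # False: the wrapper hides the type
    print(isinstance(wrapped.unwrapped, VectorEnv))  # True: unwrapping exposes the inner VectorEnv
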
18 changes: 11 additions & 7 deletions gym/vector/sync_vector_env.py
@@ -44,8 +44,9 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True
         )
 
         self._check_observation_spaces()
-        self.observations = create_empty_array(self.single_observation_space,
-                                               n=self.num_envs, fn=np.zeros)
+        self.observations = create_empty_array(
+            self.single_observation_space, n=self.num_envs, fn=np.zeros
+        )
 
         shape = (self.num_envs,)
         if isinstance(self.envs[0].unwrapped, VectorEnv):
@@ -79,18 +80,22 @@ def reset_wait(self):
 
     def step_async(self, actions):
         self._actions = actions
 
     def step_wait(self):
-        observations, infos = [], []
+        observations, rewards, dones, infos = [], [], [], []
         for i, (env, action) in enumerate(zip(self.envs, self._actions)):
-            observation, self._rewards[i], self._dones[i], info = env.step(action)
+            observation, reward, done, info = env.step(action)
             # Do nothing if the env is a VectorEnv, since it will automatically
             # reset the envs that are done if needed in the 'step' method and
             # return the initial observation instead of the final observation.
-            if not isinstance(env, VectorEnv) and self._dones[i]:
+            if not isinstance(env.unwrapped, VectorEnv) and done:
                 observation = env.reset()
             observations.append(observation)
+            rewards.append(reward)
+            dones.append(done)
             infos.append(info)
+        self._rewards = np.stack(rewards)
+        self._dones = np.stack(dones)
         self.observations = concatenate(
             observations, self.observations, self.single_observation_space
         )
@@ -102,7 +107,6 @@ def step_wait(self):
             infos,
         )
 
-
     def close_extras(self, **kwargs):
         [env.close() for env in self.envs]
 
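In sync_vector_env.py, step_wait now gathers rewards and dones into per-step lists and stacks them with np.stack rather than writing into preallocated (num_envs,) arrays. When an inner env is itself a VectorEnv, its reward and done are already arrays, so the stacked result presumably becomes a 2-D batch. A small illustration of that stacking, with made-up values for two inner VectorEnvs of three sub-envs each:

    import numpy as np

    # Plain sub-envs: scalar entries -> stacked shape (num_envs,).
    # Inner VectorEnvs: array entries of shape (inner_num_envs,)
    # -> stacked shape (num_envs, inner_num_envs).
    rewards = [np.array([1.0, 0.0, 1.0]), np.array([0.0, 1.0, 0.0])]
    dones = [np.array([False, True, False]), np.array([True, False, False])]
    print(np.stack(rewards).shape)  # (2, 3)
    print(np.stack(dones).shape)    # (2, 3)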
