Adds an optional gym wrapper, UnityEnv, to use as a Python interface to Unity environments.
Showing 13 changed files with 692 additions and 23 deletions.

@@ -78,5 +78,8 @@ python/summaries
# VSCode hidden files
*.vscode/

.DS_Store
.ipynb_checkpoints

# pytest cache
*.pytest_cache/

@@ -0,0 +1,127 @@
# Unity ML-Agents Gym Wrapper

A common way for machine learning researchers to interact with simulation environments is via a wrapper provided by OpenAI called `gym`. For more information on the gym interface, see [here](https://github.com/openai/gym).

We provide a gym wrapper, along with instructions for using it with existing machine learning algorithms which utilize gym. The wrapper provides an interface on top of our `UnityEnvironment` class, which is the default way of interfacing with a Unity environment via Python.

## Installation

The gym wrapper can be installed using:

```
pip install gym_unity
```

or by running the following from the `/gym-unity` directory of the repository:

```
pip install .
```

## Using the Gym Wrapper

The gym interface is available from `gym_unity.envs`. To launch an environment from the root of the project repository use:

```python
from gym_unity.envs import UnityEnv

env = UnityEnv(environment_filename, worker_id, use_visual, multiagent)
```

* `environment_filename` refers to the path to the Unity environment.
* `worker_id` refers to the port to use for communication with the environment. Defaults to `0`.
* `use_visual` refers to whether to use visual observations (True) or vector observations (False) as the default observation provided by the `reset` and `step` functions. Defaults to `False`.
* `multiagent` refers to whether you intend to launch an environment which contains more than one agent. Defaults to `False`.

The returned environment `env` will function as a gym.
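
For example, a minimal interaction loop with the wrapped environment might look like the sketch below (the `./envs/GridWorld` path is only an example of a built, single-agent Unity environment):

```python
from gym_unity.envs import UnityEnv

# "./envs/GridWorld" is an example path to a built Unity environment.
env = UnityEnv("./envs/GridWorld", worker_id=0, use_visual=False)

obs = env.reset()
for _ in range(100):
    action = env.action_space.sample()           # sample from the gym action space
    obs, reward, done, info = env.step(action)   # standard gym step signature
    if done:
        obs = env.reset()
env.close()
```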

For more on using the gym interface, see our [Jupyter Notebook tutorial](../python/notebooks/getting-started-gym.ipynb).

## Limitations

* It is only possible to use an environment with a single Brain.
* By default the first visual observation is provided as the `observation`, if present. Otherwise vector observations are provided.
* All `BrainInfo` output from the environment can still be accessed from the `info` provided by `env.step(action)`, as shown in the snippet after this list.
* Stacked vector observations are not supported.
* Environment registration for use with `gym.make()` is currently not supported.
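
As a sketch of the `BrainInfo` point, the underlying object can be read back out of the `info` dictionary returned by `step` (assuming a single-agent environment already created as `env`):

```python
obs, reward, done, info = env.step(env.action_space.sample())

brain_info = info["brain_info"]             # raw BrainInfo from the ML-Agents UnityEnvironment
print(info["text_observation"])             # text observations are passed through as well
print(brain_info.rewards, brain_info.local_done)
```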

## Running OpenAI Baselines Algorithms

OpenAI provides a set of open-source, maintained and tested Reinforcement Learning algorithms called [Baselines](https://github.com/openai/baselines).

Using the provided Gym wrapper, it is possible to train ML-Agents environments using these algorithms. This requires the creation of custom training scripts to launch each algorithm. In most cases these scripts can be created by making slight modifications to the ones provided for Atari and Mujoco environments.

### Example - DQN Baseline

In order to train an agent to play the `GridWorld` environment using the Baselines DQN algorithm, create a file called `train_unity.py` within the `baselines/deepq/experiments` subfolder of the baselines repository. This file will be a modification of the `run_atari.py` file within the same folder. Then create an `/envs/` directory within the repository, and build the GridWorld environment into that directory. For more information on building Unity environments, see [here](../docs/Learning-Environment-Executable.md). Add the following code to the `train_unity.py` file:

```python
import gym

from baselines import deepq
from gym_unity.envs import UnityEnv


def main():
    # Launch the GridWorld build with worker_id 0 and visual observations.
    env = UnityEnv("./envs/GridWorld", 0, use_visual=True)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-3,
        max_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
    )
    print("Saving model to unity_model.pkl")
    act.save("unity_model.pkl")


if __name__ == '__main__':
    main()
```

To start the training process, run the following from the root of the baselines repository:

```
python -m baselines.deepq.experiments.train_unity
```

### Other Algorithms

Other algorithms in the Baselines repository can be run using scripts similar to the example provided above. In most cases, the primary changes needed to use a Unity environment are to import `UnityEnv`, and to replace the environment creation code, typically `gym.make()`, with a call to `UnityEnv(env_path)` passing the environment binary path.
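
For illustration, that swap inside an existing Baselines script might look like the following (the Atari id and the Unity build path here are only examples):

```python
from gym_unity.envs import UnityEnv

# env = gym.make("PongNoFrameskip-v4")                    # original environment creation
env = UnityEnv("./envs/GridWorld", 0, use_visual=True)    # Unity environment binary instead
```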

A typical rule of thumb is that for vision-based environments, modification should be done to Atari training scripts, and for vector observation environments, modification should be done to Mujoco scripts.

Some algorithms will make use of `make_atari_env()` or `make_mujoco_env()` functions. These are defined in `baselines/common/cmd_util.py`. In order to use Unity environments for these algorithms, add the following import statement and function to `cmd_util.py`:

```python
from gym_unity.envs import UnityEnv


def make_unity_env(env_directory, num_env, visual, start_index=0):
    """
    Create a wrapped, monitored Unity environment.
    """
    # Relies on names already imported at the top of cmd_util.py
    # (os, MPI, Monitor, logger, SubprocVecEnv).
    def make_env(rank):  # pylint: disable=C0111
        def _thunk():
            env = UnityEnv(env_directory, rank, use_visual=True)
            env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            return env
        return _thunk
    if visual:
        # Visual environments are run as parallel subprocess workers.
        return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
    else:
        # Vector-observation environments run as a single, MPI-ranked environment.
        rank = MPI.COMM_WORLD.Get_rank()
        env = UnityEnv(env_directory, rank, use_visual=False)
        env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
        return env
```
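
Once this function is in place, a training script could then build the vectorized Unity environment with a call such as the following (the path and `num_env` value are only examples):

```python
env = make_unity_env('./envs/GridWorld', num_env=4, visual=True)
```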

@@ -0,0 +1 @@
from gym.envs.registration import register

@@ -0,0 +1 @@
from gym_unity.envs.unity_env import UnityEnv, UnityGymException

@@ -0,0 +1,207 @@
import gym
import numpy as np
from unityagents import UnityEnvironment
from gym import error, spaces, logger


class UnityGymException(error.Error):
    """
    Any error related to the gym wrapper of ml-agents.
    """
    pass


class UnityEnv(gym.Env):
    """
    Provides Gym wrapper for Unity Learning Environments.
    Multi-agent environments use lists for object types, as done here:
    https://github.com/openai/multiagent-particle-envs
    """

    def __init__(self, environment_filename: str, worker_id=0, use_visual=False, multiagent=False):
        """
        Environment initialization
        :param environment_filename: The UnityEnvironment path or file to be wrapped in the gym.
        :param worker_id: Worker number for environment.
        :param use_visual: Whether to use visual observation or vector observation.
        :param multiagent: Whether to run in multi-agent mode (lists of obs, reward, done).
        """
        self._env = UnityEnvironment(environment_filename, worker_id)
        self.name = self._env.academy_name
        self.visual_obs = None
        self._current_state = None
        self._n_agents = None
        self._multiagent = multiagent

        # Check brain configuration
        if len(self._env.brains) != 1:
            raise UnityGymException(
                "There can only be one brain in a UnityEnvironment "
                "if it is wrapped in a gym.")
        self.brain_name = self._env.external_brain_names[0]
        brain = self._env.brains[self.brain_name]

        if use_visual and brain.number_visual_observations == 0:
            raise UnityGymException("`use_visual` was set to True, however there are no"
                                    " visual observations as part of this environment.")
        self.use_visual = brain.number_visual_observations == 1 and use_visual

        if brain.num_stacked_vector_observations != 1:
            raise UnityGymException(
                "There can only be one stacked vector observation in a UnityEnvironment "
                "if it is wrapped in a gym.")

        # Check for number of agents in scene.
        initial_info = self._env.reset()[self.brain_name]
        self._check_agents(len(initial_info.agents))

        # Set observation and action spaces
        if brain.vector_action_space_type == "discrete":
            if len(brain.vector_action_space_size) == 1:
                self._action_space = spaces.Discrete(brain.vector_action_space_size[0])
            else:
                self._action_space = spaces.MultiDiscrete(brain.vector_action_space_size)
        else:
            high = np.array([1] * brain.vector_action_space_size[0])
            self._action_space = spaces.Box(-high, high, dtype=np.float32)
        high = np.array([np.inf] * brain.vector_observation_space_size)
        self.action_meanings = brain.vector_action_descriptions
        if self.use_visual:
            if brain.camera_resolutions[0]["blackAndWhite"]:
                depth = 1
            else:
                depth = 3
            self._observation_space = spaces.Box(0, 1, dtype=np.float32,
                                                 shape=(brain.camera_resolutions[0]["height"],
                                                        brain.camera_resolutions[0]["width"],
                                                        depth))
        else:
            self._observation_space = spaces.Box(-high, high, dtype=np.float32)

    def reset(self):
        """Resets the state of the environment and returns an initial observation.
        In the case of multi-agent environments, this is a list.
        Returns: observation (object/list): the initial observation of the
            space.
        """
        info = self._env.reset()[self.brain_name]
        n_agents = len(info.agents)
        self._check_agents(n_agents)

        if not self._multiagent:
            obs, reward, done, info = self._single_step(info)
        else:
            obs, reward, done, info = self._multi_step(info)
        return obs

    def step(self, action):
        """Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.
        Accepts an action and returns a tuple (observation, reward, done, info).
        In the case of multi-agent environments, these are lists.
        Args:
            action (object/list): an action provided by the environment
        Returns:
            observation (object/list): agent's observation of the current environment
            reward (float/list) : amount of reward returned after previous action
            done (boolean/list): whether the episode has ended.
            info (dict): contains auxiliary diagnostic information, including BrainInfo.
        """

        # In multi-agent mode, `action` must be a list with one action per agent.
        if self._multiagent:
            if not isinstance(action, list):
                raise UnityGymException("The environment was expecting `action` to be a list.")
            if len(action) != self._n_agents:
                raise UnityGymException(
                    "The environment was expecting a list of {} actions.".format(self._n_agents))
            else:
                action = np.array(action)

        info = self._env.step(action)[self.brain_name]
        n_agents = len(info.agents)
        self._check_agents(n_agents)
        self._current_state = info

        if not self._multiagent:
            obs, reward, done, info = self._single_step(info)
        else:
            obs, reward, done, info = self._multi_step(info)
        return obs, reward, done, info

    def _single_step(self, info):
        if self.use_visual:
            self.visual_obs = info.visual_observations[0][0, :, :, :]
            default_observation = self.visual_obs
        else:
            default_observation = info.vector_observations[0, :]

        return default_observation, info.rewards[0], info.local_done[0], {
            "text_observation": info.text_observations[0],
            "brain_info": info}

    def _multi_step(self, info):
        if self.use_visual:
            self.visual_obs = info.visual_observations
            default_observation = self.visual_obs
        else:
            default_observation = info.vector_observations
        return list(default_observation), info.rewards, info.local_done, {
            "text_observation": info.text_observations,
            "brain_info": info}

    def render(self, mode='rgb_array'):
        return self.visual_obs

    def close(self):
        """Override _close in your subclass to perform any necessary cleanup.
        Environments will automatically close() themselves when
        garbage collected or when the program exits.
        """
        self._env.close()

    def get_action_meanings(self):
        return self.action_meanings

    def seed(self, seed=None):
        """Sets the seed for this env's random number generator(s).
        Currently not implemented.
        """
        logger.warn("Could not seed environment %s", self.name)
        return

    def _check_agents(self, n_agents):
        if not self._multiagent and n_agents > 1:
            raise UnityGymException("The environment was launched as a single-agent environment, "
                                    "however there is more than one agent in the scene.")
        elif self._multiagent and n_agents <= 1:
            raise UnityGymException("The environment was launched as a multi-agent environment, "
                                    "however there is only one agent in the scene.")
        if self._n_agents is None:
            self._n_agents = n_agents
            logger.info("{} agents within environment.".format(n_agents))
        elif self._n_agents != n_agents:
            raise UnityGymException("The number of agents in the environment has changed since "
                                    "initialization. This is not supported.")

    @property
    def metadata(self):
        return {'render.modes': ['rgb_array']}

    @property
    def reward_range(self):
        return -float('inf'), float('inf')

    @property
    def spec(self):
        return None

    @property
    def action_space(self):
        return self._action_space

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def number_agents(self):
        return self._n_agents

@@ -0,0 +1,14 @@
#!/usr/bin/env python

from setuptools import setup, Command, find_packages

setup(name='gym_unity',
      version='0.1.0',
      description='Unity Machine Learning Agents Gym Interface',
      license='Apache License 2.0',
      author='Unity Technologies',
      author_email='[email protected]',
      url='https://github.com/Unity-Technologies/ml-agents',
      packages=find_packages(),
      install_requires=['gym', 'unityagents']
      )