diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6769e21 --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7e5680c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,62 @@ +FROM nvidia/cuda:11.3.1-base-ubuntu20.04 + +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ + sudo \ + git \ + bzip2 \ + libx11-6 \ + && rm -rf /var/lib/apt/lists/* + +# Create a working directory +RUN mkdir /app +WORKDIR /app + +# Create a non-root user and switch to it +RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ + && chown -R user:user /app +RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user +USER user + +# All users can use /home/user as their home directory +ENV HOME=/home/user +#COPY . /home/user/prm-rl +RUN mkdir $HOME/.cache $HOME/.config \ + && sudo chmod -R 777 $HOME + +# Set up the Conda environment +ENV CONDA_AUTO_UPDATE_CONDA=false \ + PATH=$HOME/miniconda/bin:$PATH +COPY environment.yml /app/environment.yml +RUN curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py39_4.10.3-Linux-x86_64.sh \ + && chmod +x ~/miniconda.sh \ + && ~/miniconda.sh -b -p ~/miniconda \ + && rm ~/miniconda.sh \ + && conda env update -n base -f /app/environment.yml \ + && rm /app/environment.yml \ + && conda clean -ya + +RUN sudo rm /etc/apt/sources.list.d/cuda.list +RUN sudo rm /etc/apt/sources.list.d/nvidia-ml.list + +RUN sudo apt-get update -q && sudo DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libgl1-mesa-dev \ + libgl1-mesa-glx \ + libglew-dev \ + libosmesa6-dev \ + software-properties-common \ + vim \ + wget \ + gcc \ + && sudo apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \ + && sudo chmod +x /usr/local/bin/patchelf + +RUN sudo mkdir -p /home/user/.mujoco \ + && sudo wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ + && sudo tar -xf mujoco.tar.gz -C /home/user/.mujoco \ + && sudo rm mujoco.tar.gz diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..77e77f9 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2021 Decision Transformer (Decision Transformer: Reinforcement Learning via Sequence Modeling) Authors (https://arxiv.org/abs/2106.01345) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..65e54c2 --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ + +# Decision Transformer + +Lili Chen\*, Kevin Lu\*, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas†, and Igor Mordatch† + +\*equal contribution, †equal advising + +A link to our paper can be found on [arXiv](https://arxiv.org/abs/2106.01345). + +## Overview + +Official codebase for [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://sites.google.com/berkeley.edu/decision-transformer). +Contains scripts to reproduce experiments. + +![image info](./architecture.png) + +## Instructions + +We provide code in two sub-directories: `atari` containing code for Atari experiments and `gym` containing code for OpenAI Gym experiments. +See corresponding READMEs in each folder for instructions; scripts should be run from the respective directories. +It may be necessary to add the respective directories to your PYTHONPATH. + +## Citation + +Please cite our paper as: + +``` +@article{chen2021decisiontransformer, + title={Decision Transformer: Reinforcement Learning via Sequence Modeling}, + author={Lili Chen and Kevin Lu and Aravind Rajeswaran and Kimin Lee and Aditya Grover and Michael Laskin and Pieter Abbeel and Aravind Srinivas and Igor Mordatch}, + journal={arXiv preprint arXiv:2106.01345}, + year={2021} +} +``` + +Note: this is not an official Google or Facebook product. + +## License + +MIT diff --git a/architecture.png b/architecture.png new file mode 100644 index 0000000..bd2e104 Binary files /dev/null and b/architecture.png differ diff --git a/atari/LICENSE b/atari/LICENSE new file mode 100644 index 0000000..3d89960 --- /dev/null +++ b/atari/LICENSE @@ -0,0 +1,7 @@ +The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/atari/conda_env.yml b/atari/conda_env.yml new file mode 100644 index 0000000..baff943 --- /dev/null +++ b/atari/conda_env.yml @@ -0,0 +1,22 @@ +name: decision-transformer-atari +channels: +- pytorch +dependencies: +- python=3.7.9 +- pytorch=1.2 +- cudatoolkit=10. 
+- numpy +- psutil +- opencv +- pip +- pip: + - atari-py + - pyprind + - tensorflow-gpu>=1.13 + - absl-py + - atari-py + - gin-config + - gym + - tqdm + - blosc + - git+https://github.com/google/dopamine.git diff --git a/atari/create_dataset.py b/atari/create_dataset.py new file mode 100644 index 0000000..c10c6fe --- /dev/null +++ b/atari/create_dataset.py @@ -0,0 +1,102 @@ +import csv +import logging +# make deterministic +from mingpt.utils import set_seed +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from torch.utils.data import Dataset +from mingpt.model_atari import GPT, GPTConfig +from mingpt.trainer_atari import Trainer, TrainerConfig +from mingpt.utils import sample +from collections import deque +import random +import torch +import pickle +import blosc +import argparse +from fixed_replay_buffer import FixedReplayBuffer + +def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_per_buffer): + # -- load data from memory (make more efficient) + obss = [] + actions = [] + returns = [0] + done_idxs = [] + stepwise_returns = [] + + transitions_per_buffer = np.zeros(50, dtype=int) + num_trajectories = 0 + while len(obss) < num_steps: + buffer_num = np.random.choice(np.arange(50 - num_buffers, 50), 1)[0] + i = transitions_per_buffer[buffer_num] + print('loading from buffer %d which has %d already loaded' % (buffer_num, i)) + frb = FixedReplayBuffer( + data_dir=data_dir_prefix + game + '/1/replay_logs', + replay_suffix=buffer_num, + observation_shape=(84, 84), + stack_size=4, + update_horizon=1, + gamma=0.99, + observation_dtype=np.uint8, + batch_size=32, + replay_capacity=100000) + if frb._loaded_buffers: + done = False + curr_num_transitions = len(obss) + trajectories_to_load = trajectories_per_buffer + while not done: + states, ac, ret, next_states, next_action, next_reward, terminal, indices = frb.sample_transition_batch(batch_size=1, indices=[i]) + states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) + obss += [states] + actions += [ac[0]] + stepwise_returns += [ret[0]] + if terminal[0]: + done_idxs += [len(obss)] + returns += [0] + if trajectories_to_load == 0: + done = True + else: + trajectories_to_load -= 1 + returns[-1] += ret[0] + i += 1 + if i >= 100000: + obss = obss[:curr_num_transitions] + actions = actions[:curr_num_transitions] + stepwise_returns = stepwise_returns[:curr_num_transitions] + returns[-1] = 0 + i = transitions_per_buffer[buffer_num] + done = True + num_trajectories += (trajectories_per_buffer - trajectories_to_load) + transitions_per_buffer[buffer_num] = i + print('this buffer has %d loaded transitions and there are now %d transitions total divided into %d trajectories' % (i, len(obss), num_trajectories)) + + actions = np.array(actions) + returns = np.array(returns) + stepwise_returns = np.array(stepwise_returns) + done_idxs = np.array(done_idxs) + + # -- create reward-to-go dataset + start_index = 0 + rtg = np.zeros_like(stepwise_returns) + for i in done_idxs: + i = int(i) + curr_traj_returns = stepwise_returns[start_index:i] + for j in range(i-1, start_index-1, -1): # start from i-1 + rtg_j = curr_traj_returns[j-start_index:i-start_index] + rtg[j] = sum(rtg_j) + start_index = i + print('max rtg is %d' % max(rtg)) + + # -- create timestep dataset + start_index = 0 + timesteps = np.zeros(len(actions)+1, dtype=int) + for i in done_idxs: + i = int(i) + timesteps[start_index:i+1] = np.arange(i+1 - start_index) + start_index = i+1 + print('max timestep is 
%d' % max(timesteps)) + + return obss, actions, returns, done_idxs, rtg, timesteps diff --git a/atari/fixed_replay_buffer.py b/atari/fixed_replay_buffer.py new file mode 100644 index 0000000..bfa94ce --- /dev/null +++ b/atari/fixed_replay_buffer.py @@ -0,0 +1,109 @@ +# source: https://github.com/google-research/batch_rl/blob/master/batch_rl/fixed_replay/replay_memory/fixed_replay_buffer.py + +import collections +from concurrent import futures +from dopamine.replay_memory import circular_replay_buffer +import numpy as np +import tensorflow.compat.v1 as tf +import gin + +gfile = tf.gfile + +STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX + +class FixedReplayBuffer(object): + """Object composed of a list of OutofGraphReplayBuffers.""" + + def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Initialize the FixedReplayBuffer class. + Args: + data_dir: str, log Directory from which to load the replay buffer. + replay_suffix: int, If not None, then only load the replay buffer + corresponding to the specific suffix in data directory. + *args: Arbitrary extra arguments. + **kwargs: Arbitrary keyword arguments. + """ + self._args = args + self._kwargs = kwargs + self._data_dir = data_dir + self._loaded_buffers = False + self.add_count = np.array(0) + self._replay_suffix = replay_suffix + if not self._loaded_buffers: + if replay_suffix is not None: + assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' + self.load_single_buffer(replay_suffix) + else: + self._load_replay_buffers(num_buffers=50) + + def load_single_buffer(self, suffix): + """Load a single replay buffer.""" + replay_buffer = self._load_buffer(suffix) + if replay_buffer is not None: + self._replay_buffers = [replay_buffer] + self.add_count = replay_buffer.add_count + self._num_replay_buffers = 1 + self._loaded_buffers = True + + def _load_buffer(self, suffix): + """Loads a OutOfGraphReplayBuffer replay buffer.""" + try: + # pytype: disable=attribute-error + replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer( + *self._args, **self._kwargs) + replay_buffer.load(self._data_dir, suffix) + tf.logging.info('Loaded replay buffer ckpt {} from {}'.format( + suffix, self._data_dir)) + # pytype: enable=attribute-error + return replay_buffer + except tf.errors.NotFoundError: + return None + + def _load_replay_buffers(self, num_buffers=None): + """Loads multiple checkpoints into a list of replay buffers.""" + if not self._loaded_buffers: # pytype: disable=attribute-error + ckpts = gfile.ListDirectory(self._data_dir) # pytype: disable=attribute-error + # Assumes that the checkpoints are saved in a format CKPT_NAME.{SUFFIX}.gz + ckpt_counters = collections.Counter( + [name.split('.')[-2] for name in ckpts]) + # Should contain the files for add_count, action, observation, reward, + # terminal and invalid_range + ckpt_suffixes = [x for x in ckpt_counters if ckpt_counters[x] in [6, 7]] + if num_buffers is not None: + ckpt_suffixes = np.random.choice( + ckpt_suffixes, num_buffers, replace=False) + self._replay_buffers = [] + # Load the replay buffers in parallel + with futures.ThreadPoolExecutor( + max_workers=num_buffers) as thread_pool_executor: + replay_futures = [thread_pool_executor.submit( + self._load_buffer, suffix) for suffix in ckpt_suffixes] + for f in replay_futures: + replay_buffer = f.result() + if replay_buffer is not None: + self._replay_buffers.append(replay_buffer) + self.add_count = max(replay_buffer.add_count, self.add_count) + 
self._num_replay_buffers = len(self._replay_buffers) + if self._num_replay_buffers: + self._loaded_buffers = True + + def get_transition_elements(self): + return self._replay_buffers[0].get_transition_elements() + + def sample_transition_batch(self, batch_size=None, indices=None): + buffer_index = np.random.randint(self._num_replay_buffers) + return self._replay_buffers[buffer_index].sample_transition_batch( + batch_size=batch_size, indices=indices) + + def load(self, *args, **kwargs): # pylint: disable=unused-argument + pass + + def reload_buffer(self, num_buffers=None): + self._loaded_buffers = False + self._load_replay_buffers(num_buffers) + + def save(self, *args, **kwargs): # pylint: disable=unused-argument + pass + + def add(self, *args, **kwargs): # pylint: disable=unused-argument + pass \ No newline at end of file diff --git a/atari/mingpt/__init__.py b/atari/mingpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/atari/mingpt/model_atari.py b/atari/mingpt/model_atari.py new file mode 100644 index 0000000..8428811 --- /dev/null +++ b/atari/mingpt/model_atari.py @@ -0,0 +1,281 @@ +""" +The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +""" +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +import math +import logging + +import torch +import torch.nn as nn +from torch.nn import functional as F + +logger = logging.getLogger(__name__) + +import numpy as np + +class GELU(nn.Module): + def forward(self, input): + return F.gelu(input) + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, vocab_size, block_size, **kwargs): + self.vocab_size = vocab_size + self.block_size = block_size + for k,v in kwargs.items(): + setattr(self, k, v) + +class GPT1Config(GPTConfig): + """ GPT-1 like network roughly 125M params """ + n_layer = 12 + n_head = 12 + n_embd = 768 + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. 
+ It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) + # .view(1, 1, config.block_size, config.block_size)) + self.register_buffer("mask", torch.tril(torch.ones(config.block_size + 1, config.block_size + 1)) + .view(1, 1, config.block_size + 1, config.block_size + 1)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + GELU(), + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, config): + super().__init__() + + self.config = config + + self.model_type = config.model_type + + # input embedding stem + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + # self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd)) + self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep+1, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.block_size = config.block_size + self.apply(self._init_weights) + + + logger.info("number of parameters: %e", sum(p.numel() for p in 
self.parameters())) + + + self.state_encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), + nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), + nn.Flatten(), nn.Linear(3136, config.n_embd), nn.Tanh()) + + self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh()) + + self.action_embeddings = nn.Sequential(nn.Embedding(config.vocab_size, config.n_embd), nn.Tanh()) + nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def configure_optimizers(self, train_config): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + # whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + no_decay.add('global_pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + # state, action, and return + def forward(self, states, actions, targets=None, rtgs=None, timesteps=None): + # states: (batch, block_size, 4*84*84) + # actions: (batch, block_size, 1) + # targets: (batch, block_size, 1) + # rtgs: (batch, block_size, 1) + # timesteps: (batch, 1, 1) + + state_embeddings = self.state_encoder(states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()) # (batch * block_size, n_embd) + state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd) # (batch, block_size, n_embd) + + if actions is not None and self.model_type == 'reward_conditioned': + rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) + action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) + + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*3 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings[:,::3,:] = rtg_embeddings + token_embeddings[:,1::3,:] = state_embeddings + token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:] + elif actions is None and self.model_type == 'reward_conditioned': # only happens at very first timestep of evaluation + rtg_embeddings = self.ret_emb(rtgs.type(torch.float32)) + + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2, self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings[:,::2,:] = rtg_embeddings # really just [:,0,:] + token_embeddings[:,1::2,:] = state_embeddings # really just [:,1,:] + elif actions is not None and self.model_type == 'naive': + action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd) + + token_embeddings = torch.zeros((states.shape[0], states.shape[1]*2 - int(targets is None), self.config.n_embd), dtype=torch.float32, device=state_embeddings.device) + token_embeddings[:,::2,:] = state_embeddings + token_embeddings[:,1::2,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:] + elif actions is None and self.model_type == 'naive': # only happens at very first timestep of evaluation + token_embeddings = state_embeddings + else: + raise NotImplementedError() + + batch_size = states.shape[0] + all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd + + position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :] + + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + if actions is not None and self.model_type == 'reward_conditioned': + logits = logits[:, 1::3, :] # only keep predictions from state_embeddings + elif actions is None and self.model_type == 'reward_conditioned': + logits = logits[:, 1:, :] + elif actions is not None and self.model_type == 'naive': + logits = logits[:, ::2, :] # only keep predictions from state_embeddings + elif actions is 
None and self.model_type == 'naive': + logits = logits # for completeness + else: + raise NotImplementedError() + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) + + return logits, loss diff --git a/atari/mingpt/trainer_atari.py b/atari/mingpt/trainer_atari.py new file mode 100644 index 0000000..9ac79e0 --- /dev/null +++ b/atari/mingpt/trainer_atari.py @@ -0,0 +1,319 @@ +""" +The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +""" +Simple training loop; Boilerplate that could apply to any arbitrary neural network, +so nothing in this file really has anything to do with GPT specifically. +""" + +import math +import logging + +from tqdm import tqdm +import numpy as np + +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data.dataloader import DataLoader + +logger = logging.getLogger(__name__) + +from mingpt.utils import sample +import atari_py +from collections import deque +import random +import cv2 +import torch +from PIL import Image + +class TrainerConfig: + # optimization parameters + max_epochs = 10 + batch_size = 64 + learning_rate = 3e-4 + betas = (0.9, 0.95) + grad_norm_clip = 1.0 + weight_decay = 0.1 # only applied on matmul weights + # learning rate decay params: linear warmup followed by cosine decay to 10% of original + lr_decay = False + warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere + final_tokens = 260e9 # (at what point we reach 10% of original LR) + # checkpoint settings + ckpt_path = None + num_workers = 0 # for DataLoader + + def __init__(self, **kwargs): + for k,v in kwargs.items(): + setattr(self, k, v) + +class Trainer: + + def __init__(self, model, train_dataset, test_dataset, config): + self.model = model + self.train_dataset = train_dataset + self.test_dataset = test_dataset + self.config = config + + # take over whatever gpus are on the system + self.device = 'cpu' + if torch.cuda.is_available(): + self.device = torch.cuda.current_device() + self.model = torch.nn.DataParallel(self.model).to(self.device) + + def save_checkpoint(self): + # DataParallel wrappers keep raw model object in .module attribute + raw_model = self.model.module if hasattr(self.model, "module") else self.model + logger.info("saving %s", self.config.ckpt_path) + # torch.save(raw_model.state_dict(), 
self.config.ckpt_path) + + def train(self): + model, config = self.model, self.config + raw_model = model.module if hasattr(self.model, "module") else model + optimizer = raw_model.configure_optimizers(config) + + def run_epoch(split, epoch_num=0): + is_train = split == 'train' + model.train(is_train) + data = self.train_dataset if is_train else self.test_dataset + loader = DataLoader(data, shuffle=True, pin_memory=True, + batch_size=config.batch_size, + num_workers=config.num_workers) + + losses = [] + pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader) + for it, (x, y, r, t) in pbar: + + # place data on the correct device + x = x.to(self.device) + y = y.to(self.device) + r = r.to(self.device) + t = t.to(self.device) + + # forward the model + with torch.set_grad_enabled(is_train): + # logits, loss = model(x, y, r) + logits, loss = model(x, y, y, r, t) + loss = loss.mean() # collapse all losses if they are scattered on multiple gpus + losses.append(loss.item()) + + if is_train: + + # backprop and update the parameters + model.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) + optimizer.step() + + # decay the learning rate based on our progress + if config.lr_decay: + self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) + if self.tokens < config.warmup_tokens: + # linear warmup + lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) + else: + # cosine learning rate decay + progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) + lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) + lr = config.learning_rate * lr_mult + for param_group in optimizer.param_groups: + param_group['lr'] = lr + else: + lr = config.learning_rate + + # report progress + pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. 
lr {lr:e}") + + if not is_train: + test_loss = float(np.mean(losses)) + logger.info("test loss: %f", test_loss) + return test_loss + + # best_loss = float('inf') + + best_return = -float('inf') + + self.tokens = 0 # counter used for learning rate decay + + for epoch in range(config.max_epochs): + + run_epoch('train', epoch_num=epoch) + # if self.test_dataset is not None: + # test_loss = run_epoch('test') + + # # supports early stopping based on the test loss, or just save always if no test set is provided + # good_model = self.test_dataset is None or test_loss < best_loss + # if self.config.ckpt_path is not None and good_model: + # best_loss = test_loss + # self.save_checkpoint() + + # -- pass in target returns + if self.config.model_type == 'naive': + eval_return = self.get_returns(0) + elif self.config.model_type == 'reward_conditioned': + if self.config.game == 'Breakout': + eval_return = self.get_returns(90) + elif self.config.game == 'Seaquest': + eval_return = self.get_returns(1150) + elif self.config.game == 'Qbert': + eval_return = self.get_returns(14000) + elif self.config.game == 'Pong': + eval_return = self.get_returns(20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + def get_returns(self, ret): + self.model.train(False) + args=Args(self.config.game.lower(), self.config.seed) + env = Env(args) + env.eval() + + T_rewards, T_Qs = [], [] + done = True + for i in range(10): + state = env.reset() + state = state.type(torch.float32).to(self.device).unsqueeze(0).unsqueeze(0) + rtgs = [ret] + # first state is from env, first rtg is target return, and first timestep is 0 + sampled_action = sample(self.model.module, state, 1, temperature=1.0, sample=True, actions=None, + rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1), + timesteps=torch.zeros((1, 1, 1), dtype=torch.int64).to(self.device)) + + j = 0 + all_states = state + actions = [] + while True: + if done: + state, reward_sum, done = env.reset(), 0, False + action = sampled_action.cpu().numpy()[0,-1] + actions += [sampled_action] + state, reward, done = env.step(action) + reward_sum += reward + j += 1 + + if done: + T_rewards.append(reward_sum) + break + + state = state.unsqueeze(0).unsqueeze(0).to(self.device) + + all_states = torch.cat([all_states, state], dim=0) + + rtgs += [rtgs[-1] - reward] + # all_states has all previous states and rtgs has all previous rtgs (will be cut to block_size in utils.sample) + # timestep is just current timestep + sampled_action = sample(self.model.module, all_states.unsqueeze(0), 1, temperature=1.0, sample=True, + actions=torch.tensor(actions, dtype=torch.long).to(self.device).unsqueeze(1).unsqueeze(0), + rtgs=torch.tensor(rtgs, dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1), + timesteps=(min(j, self.config.max_timestep) * torch.ones((1, 1, 1), dtype=torch.int64).to(self.device))) + env.close() + eval_return = sum(T_rewards)/10. 
+ print("target return: %d, eval return: %d" % (ret, eval_return)) + self.model.train(True) + return eval_return + + +class Env(): + def __init__(self, args): + self.device = args.device + self.ale = atari_py.ALEInterface() + self.ale.setInt('random_seed', args.seed) + self.ale.setInt('max_num_frames_per_episode', args.max_episode_length) + self.ale.setFloat('repeat_action_probability', 0) # Disable sticky actions + self.ale.setInt('frame_skip', 0) + self.ale.setBool('color_averaging', False) + self.ale.loadROM(atari_py.get_game_path(args.game)) # ROM loading must be done after setting options + actions = self.ale.getMinimalActionSet() + self.actions = dict([i, e] for i, e in zip(range(len(actions)), actions)) + self.lives = 0 # Life counter (used in DeepMind training) + self.life_termination = False # Used to check if resetting only from loss of life + self.window = args.history_length # Number of frames to concatenate + self.state_buffer = deque([], maxlen=args.history_length) + self.training = True # Consistent with model training mode + + def _get_state(self): + state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR) + return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255) + + def _reset_buffer(self): + for _ in range(self.window): + self.state_buffer.append(torch.zeros(84, 84, device=self.device)) + + def reset(self): + if self.life_termination: + self.life_termination = False # Reset flag + self.ale.act(0) # Use a no-op after loss of life + else: + # Reset internals + self._reset_buffer() + self.ale.reset_game() + # Perform up to 30 random no-ops before starting + for _ in range(random.randrange(30)): + self.ale.act(0) # Assumes raw action 0 is always no-op + if self.ale.game_over(): + self.ale.reset_game() + # Process and return "initial" state + observation = self._get_state() + self.state_buffer.append(observation) + self.lives = self.ale.lives() + return torch.stack(list(self.state_buffer), 0) + + def step(self, action): + # Repeat action 4 times, max pool over last 2 frames + frame_buffer = torch.zeros(2, 84, 84, device=self.device) + reward, done = 0, False + for t in range(4): + reward += self.ale.act(self.actions.get(action)) + if t == 2: + frame_buffer[0] = self._get_state() + elif t == 3: + frame_buffer[1] = self._get_state() + done = self.ale.game_over() + if done: + break + observation = frame_buffer.max(0)[0] + self.state_buffer.append(observation) + # Detect loss of life as terminal in training mode + if self.training: + lives = self.ale.lives() + if lives < self.lives and lives > 0: # Lives > 0 for Q*bert + self.life_termination = not done # Only set flag when not truly done + done = True + self.lives = lives + # Return state, reward, done + return torch.stack(list(self.state_buffer), 0), reward, done + + # Uses loss of life as terminal signal + def train(self): + self.training = True + + # Uses standard terminal signal + def eval(self): + self.training = False + + def action_space(self): + return len(self.actions) + + def render(self): + cv2.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1]) + cv2.waitKey(1) + + def close(self): + cv2.destroyAllWindows() + +class Args: + def __init__(self, game, seed): + self.device = torch.device('cuda') + self.seed = seed + self.max_episode_length = 108e3 + self.game = game + self.history_length = 4 diff --git a/atari/mingpt/utils.py b/atari/mingpt/utils.py new file mode 100644 index 0000000..b97ff5b --- /dev/null +++ b/atari/mingpt/utils.py @@ -0,0 +1,62 @@ +""" +The MIT 
License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +import random +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +@torch.no_grad() +def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, actions=None, rtgs=None, timesteps=None): + """ + take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in + the sequence, feeding the predictions back into the model each time. Clearly the sampling + has quadratic complexity unlike an RNN that is only linear, and has a finite context window + of block_size, unlike an RNN that has an infinite context window. + """ + block_size = model.get_block_size() + model.eval() + for k in range(steps): + # x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed + x_cond = x if x.size(1) <= block_size//3 else x[:, -block_size//3:] # crop context if needed + if actions is not None: + actions = actions if actions.size(1) <= block_size//3 else actions[:, -block_size//3:] # crop context if needed + rtgs = rtgs if rtgs.size(1) <= block_size//3 else rtgs[:, -block_size//3:] # crop context if needed + logits, _ = model(x_cond, actions=actions, targets=None, rtgs=rtgs, timesteps=timesteps) + # pluck the logits at the final step and scale by temperature + logits = logits[:, -1, :] / temperature + # optionally crop probabilities to only the top k options + if top_k is not None: + logits = top_k_logits(logits, top_k) + # apply softmax to convert to probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution or take the most likely + if sample: + ix = torch.multinomial(probs, num_samples=1) + else: + _, ix = torch.topk(probs, k=1, dim=-1) + # append to the sequence and continue + # x = torch.cat((x, ix), dim=1) + x = ix + + return x diff --git a/atari/readme-atari.md b/atari/readme-atari.md new file mode 100644 index 0000000..54e2d81 --- /dev/null +++ b/atari/readme-atari.md @@ -0,0 +1,28 @@ + +# Atari + +We build our Atari implementation on top of [minGPT](https://github.com/karpathy/minGPT) and benchmark our results on the [DQN-replay](https://github.com/google-research/batch_rl) dataset. 
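For orientation, the sketch below mirrors how `create_dataset.py` reads the DQN-replay checkpoints through `FixedReplayBuffer`; the `./dqn_replay` path, the `Breakout` game, and the buffer index are placeholder assumptions, and the data itself is downloaded as described in the sections below.

```
# Illustrative sketch only -- mirrors the FixedReplayBuffer call in create_dataset.py.
# Assumes the Breakout DQN-replay logs were downloaded to ./dqn_replay (see below).
import numpy as np
from fixed_replay_buffer import FixedReplayBuffer

frb = FixedReplayBuffer(
    data_dir='./dqn_replay/Breakout/1/replay_logs',
    replay_suffix=49,                # one of the 50 checkpointed buffers (0-49)
    observation_shape=(84, 84),
    stack_size=4,
    update_horizon=1,
    gamma=0.99,
    observation_dtype=np.uint8,
    batch_size=32,
    replay_capacity=100000)

if frb._loaded_buffers:
    # states: (1, 84, 84, 4) uint8 frame stack; actions/rewards are length-1 arrays
    states, actions, rewards, *_ = frb.sample_transition_batch(batch_size=1, indices=[0])
```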
+ +## Installation + +Dependencies can be installed with the following command: + +``` +conda env create -f conda_env.yml +``` + +## Downloading datasets + +Create a directory for the dataset and load the dataset using [gsutil](https://cloud.google.com/storage/docs/gsutil_install#install). Replace `[DIRECTORY_NAME]` and `[GAME_NAME]` accordingly (e.g., `./dqn_replay` for `[DIRECTORY_NAME]` and `Breakout` for `[GAME_NAME]`) +``` +mkdir [DIRECTORY_NAME] +gsutil -m cp -R gs://atari-replay-datasets/dqn/[GAME_NAME] [DIRECTORY_NAME] +``` + +## Example usage + +Scripts to reproduce our Decision Transformer results can be found in `run.sh`. + +``` +python run_dt_atari.py --seed 123 --block_size 90 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 --data_dir_prefix [DIRECTORY_NAME] +``` diff --git a/atari/run.sh b/atari/run.sh new file mode 100644 index 0000000..c6cb918 --- /dev/null +++ b/atari/run.sh @@ -0,0 +1,41 @@ +# Decision Transformer (DT) +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'reward_conditioned' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 +done + +# Behavior Cloning (BC) +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Breakout' --batch_size 128 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Qbert' --batch_size 128 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 50 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Pong' --batch_size 512 +done + +for seed in 123 231 312 +do + python run_dt_atari.py --seed $seed --context_length 30 --epochs 5 --model_type 'naive' --num_steps 500000 --num_buffers 50 --game 'Seaquest' --batch_size 128 +done \ No newline at end of file diff --git a/atari/run_dt_atari.py b/atari/run_dt_atari.py new file mode 100644 index 0000000..6e14233 --- /dev/null +++ b/atari/run_dt_atari.py @@ -0,0 +1,90 @@ +import csv +import logging +# make deterministic +from mingpt.utils import set_seed +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from torch.utils.data import Dataset +from mingpt.model_atari import GPT, GPTConfig +from mingpt.trainer_atari import Trainer, TrainerConfig +from mingpt.utils import sample +from collections import deque +import random +import torch +import pickle +import blosc +import argparse +from create_dataset import create_dataset + +parser = argparse.ArgumentParser() +parser.add_argument('--seed', type=int, default=123) +parser.add_argument('--context_length', type=int, default=30) 
+parser.add_argument('--epochs', type=int, default=5) +parser.add_argument('--model_type', type=str, default='reward_conditioned') +parser.add_argument('--num_steps', type=int, default=500000) +parser.add_argument('--num_buffers', type=int, default=50) +parser.add_argument('--game', type=str, default='Breakout') +parser.add_argument('--batch_size', type=int, default=128) +# +parser.add_argument('--trajectories_per_buffer', type=int, default=10, help='Number of trajectories to sample from each of the buffers.') +parser.add_argument('--data_dir_prefix', type=str, default='./dqn_replay/') +args = parser.parse_args() + +set_seed(args.seed) + +class StateActionReturnDataset(Dataset): + + def __init__(self, data, block_size, actions, done_idxs, rtgs, timesteps): + self.block_size = block_size + self.vocab_size = max(actions) + 1 + self.data = data + self.actions = actions + self.done_idxs = done_idxs + self.rtgs = rtgs + self.timesteps = timesteps + + def __len__(self): + return len(self.data) - self.block_size + + def __getitem__(self, idx): + block_size = self.block_size // 3 + done_idx = idx + block_size + for i in self.done_idxs: + if i > idx: # first done_idx greater than idx + done_idx = min(int(i), done_idx) + break + idx = done_idx - block_size + states = torch.tensor(np.array(self.data[idx:done_idx]), dtype=torch.float32).reshape(block_size, -1) # (block_size, 4*84*84) + states = states / 255. + actions = torch.tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) + rtgs = torch.tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) + timesteps = torch.tensor(self.timesteps[idx:idx+1], dtype=torch.int64).unsqueeze(1) + + return states, actions, rtgs, timesteps + +obss, actions, returns, done_idxs, rtgs, timesteps = create_dataset(args.num_buffers, args.num_steps, args.game, args.data_dir_prefix, args.trajectories_per_buffer) + +# set up logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +train_dataset = StateActionReturnDataset(obss, args.context_length*3, actions, done_idxs, rtgs, timesteps) + +mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, + n_layer=6, n_head=8, n_embd=128, model_type=args.model_type, max_timestep=max(timesteps)) +model = GPT(mconf) + +# initialize a trainer instance and kick off training +epochs = args.epochs +tconf = TrainerConfig(max_epochs=epochs, batch_size=args.batch_size, learning_rate=6e-4, + lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*args.context_length*3, + num_workers=4, seed=args.seed, model_type=args.model_type, game=args.game, max_timestep=max(timesteps)) +trainer = Trainer(model, train_dataset, None, tconf) + +trainer.train() diff --git a/d4rl/.gitignore b/d4rl/.gitignore new file mode 100644 index 0000000..e4a1406 --- /dev/null +++ b/d4rl/.gitignore @@ -0,0 +1,130 @@ +.idea +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/d4rl/LICENSE b/d4rl/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/d4rl/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/d4rl/MANIFEST.in b/d4rl/MANIFEST.in new file mode 100644 index 0000000..c359966 --- /dev/null +++ b/d4rl/MANIFEST.in @@ -0,0 +1,3 @@ +recursive-include * *.xml +recursive-include * *.stl +recursive-include * *.png diff --git a/d4rl/README.md b/d4rl/README.md new file mode 100644 index 0000000..92cdb16 --- /dev/null +++ b/d4rl/README.md @@ -0,0 +1,113 @@ +# D4RL: Datasets for Deep Data-Driven Reinforcement Learning +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +[![License](https://licensebuttons.net/l/by/3.0/88x31.png)](https://creativecommons.org/licenses/by/4.0/) + +D4RL is an open-source benchmark for offline reinforcement learning. It provides standardized environments and datasets for training and benchmarking algorithms. A supplementary [whitepaper](https://arxiv.org/abs/2004.07219) and [website](https://sites.google.com/view/d4rl/home) are also available. 
+ +## Setup + +D4RL can be installed by cloning the repository as follows: +``` +git clone https://github.com/rail-berkeley/d4rl.git +cd d4rl +pip install -e . +``` + +Or, alternatively: +``` +pip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl +``` + +The control environments require MuJoCo as a dependency. You may need to obtain a [license](https://www.roboti.us/license.html) and follow the setup instructions for mujoco_py. This mostly involves copying the key to your MuJoCo installation folder. + +The Flow and CARLA tasks also require additional installation steps: +- Instructions for installing CARLA can be found [here](https://github.com/rail-berkeley/d4rl/wiki/CARLA-Setup) +- Instructions for installing Flow can be found [here](https://flow.readthedocs.io/en/latest/flow_setup.html). Make sure to install using the SUMO simulator, and add the flow repository to your PYTHONPATH once finished. + +## Using d4rl + +d4rl uses the [OpenAI Gym](https://github.com/openai/gym) API. Tasks are created via the `gym.make` function. A full list of all tasks is [available here](https://github.com/rail-berkeley/d4rl/wiki/Tasks). + +Each task is associated with a fixed offline dataset, which can be obtained with the `env.get_dataset()` method. This method returns a dictionary with: +- `observations`: An N by observation dimensional array of observations. +- `actions`: An N by action dimensional array of actions. +- `rewards`: An N dimensional array of rewards. +- `terminals`: An N dimensional array of episode termination flags. This is true when episodes end due to termination conditions such as falling over. +- `timeouts`: An N dimensional array of termination flags. This is true when episodes end due to reaching the maximum episode length. +- `infos`: Contains optional task-specific debugging information. + +You can also load data using `d4rl.qlearning_dataset(env)`, which formats the data for use by typical Q-learning algorithms by adding a `next_observations` key. + +```python +import gym +import d4rl # Import required to register environments + +# Create the environment +env = gym.make('maze2d-umaze-v1') + +# d4rl abides by the OpenAI gym interface +env.reset() +env.step(env.action_space.sample()) + +# Each task is associated with a dataset +# dataset contains observations, actions, rewards, terminals, and infos +dataset = env.get_dataset() +print(dataset['observations']) # An N x dim_observation Numpy array of observations + +# Alternatively, use d4rl.qlearning_dataset which +# also adds next_observations. +dataset = d4rl.qlearning_dataset(env) +``` + +Datasets are automatically downloaded to the `~/.d4rl/datasets` directory when `get_dataset()` is called. If you would like to change the location of this directory, you can set the `$D4RL_DATASET_DIR` environment variable to the directory of your choosing, or pass in the dataset filepath directly into the `get_dataset` method. + +### Normalizing Scores +You can use the `env.get_normalized_score(returns)` function to compute a normalized score for an episode, where `returns` is the undiscounted total sum of rewards accumulated during an episode. + +The individual min and max reference scores are stored in `d4rl/infos.py` for reference. + +## Algorithm Implementations + +We have aggregated implementations of various offline RL algorithms in a [separate repository](https://github.com/rail-berkeley/d4rl_evaluations). 
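+
+As a minimal sketch of the score normalization described above (the `episode_return` value below is a made-up placeholder, not a reference result; the task name is the same `maze2d-umaze-v1` used earlier):
+
+```python
+import gym
+import d4rl  # registers the d4rl environments with gym
+
+env = gym.make('maze2d-umaze-v1')
+
+# Placeholder return: in practice this is the undiscounted sum of rewards
+# collected by your policy over one evaluation episode.
+episode_return = 100.0
+
+# Maps the raw return onto the reference scale stored in d4rl/infos.py:
+# roughly 0 for the random-policy reference and roughly 1 for the reference max.
+print(env.get_normalized_score(episode_return))
+```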
+ +## Off-Policy Evaluations + +D4RL currently has limited support for off-policy evaluation methods, on a select few locomotion tasks. We provide trained reference policies and a set of performance metrics. Additional details can be found in the [wiki](https://github.com/rail-berkeley/d4rl/wiki/Off-Policy-Evaluation). + +## Recent Updates + +### 2-12-2020 +- Added new Gym-MuJoCo datasets (labeled v2) which fixed Hopper's performance and the qpos/qvel fields. +- Added additional wiki documentation on [generating datasets](https://github.com/rail-berkeley/d4rl/wiki/Dataset-Reproducibility-Guide). + + +## Acknowledgements + +D4RL builds on top of several excellent domains and environments built by various researchers. We would like to thank the authors of: +- [hand_dapg](https://github.com/aravindr93/hand_dapg) +- [gym-minigrid](https://github.com/maximecb/gym-minigrid) +- [carla](https://github.com/carla-simulator/carla) +- [flow](https://github.com/flow-project/flow) +- [adept_envs](https://github.com/google-research/relay-policy-learning) + +## Citation + +Please use the following bibtex for citations: + +``` +@misc{fu2020d4rl, + title={D4RL: Datasets for Deep Data-Driven Reinforcement Learning}, + author={Justin Fu and Aviral Kumar and Ofir Nachum and George Tucker and Sergey Levine}, + year={2020}, + eprint={2004.07219}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +## Licenses + +Unless otherwise noted, all datasets are licensed under the [Creative Commons Attribution 4.0 License (CC BY)](https://creativecommons.org/licenses/by/4.0/), and code is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0.html). + + diff --git a/d4rl/d4rl/__init__.py b/d4rl/d4rl/__init__.py new file mode 100644 index 0000000..6ad9b47 --- /dev/null +++ b/d4rl/d4rl/__init__.py @@ -0,0 +1,186 @@ +import os +import sys +import collections +import numpy as np + +import d4rl.infos +from d4rl.offline_env import set_dataset_path, get_keys + +SUPPRESS_MESSAGES = bool(os.environ.get('D4RL_SUPPRESS_IMPORT_ERROR', 0)) + +_ERROR_MESSAGE = 'Warning: %s failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.' 
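+
+# Each optional environment family below is imported inside its own try/except so that a
+# missing heavyweight dependency (e.g. MuJoCo, Flow, CARLA, or pybullet) only disables that
+# family and prints the warning message above, rather than making `import d4rl` fail outright.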
+ +try: + import d4rl.locomotion + import d4rl.hand_manipulation_suite + import d4rl.pointmaze + import d4rl.gym_minigrid + import d4rl.gym_mujoco +except ImportError as e: + if not SUPPRESS_MESSAGES: + print(_ERROR_MESSAGE % 'Mujoco-based envs', file=sys.stderr) + print(e, file=sys.stderr) + +try: + import d4rl.flow +except ImportError as e: + if not SUPPRESS_MESSAGES: + print(_ERROR_MESSAGE % 'Flow', file=sys.stderr) + print(e, file=sys.stderr) + +try: + import d4rl.kitchen +except ImportError as e: + if not SUPPRESS_MESSAGES: + print(_ERROR_MESSAGE % 'FrankaKitchen', file=sys.stderr) + print(e, file=sys.stderr) + +try: + import d4rl.carla +except ImportError as e: + if not SUPPRESS_MESSAGES: + print(_ERROR_MESSAGE % 'CARLA', file=sys.stderr) + print(e, file=sys.stderr) + +try: + import d4rl.gym_bullet + import d4rl.pointmaze_bullet +except ImportError as e: + if not SUPPRESS_MESSAGES: + print(_ERROR_MESSAGE % 'GymBullet', file=sys.stderr) + print(e, file=sys.stderr) + +def reverse_normalized_score(env_name, score): + ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] + ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] + return (score * (ref_max_score - ref_min_score)) + ref_min_score + +def get_normalized_score(env_name, score): + ref_min_score = d4rl.infos.REF_MIN_SCORE[env_name] + ref_max_score = d4rl.infos.REF_MAX_SCORE[env_name] + return (score - ref_min_score) / (ref_max_score - ref_min_score) + +def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs): + """ + Returns datasets formatted for use by standard Q-learning algorithms, + with observations, actions, next_observations, rewards, and a terminal + flag. + + Args: + env: An OfflineEnv object. + dataset: An optional dataset to pass in for processing. If None, + the dataset will default to env.get_dataset() + terminate_on_end (bool): Set done=True on the last timestep + in a trajectory. Default is False, and will discard the + last timestep in each trajectory. + **kwargs: Arguments to pass to env.get_dataset(). + + Returns: + A dictionary containing keys: + observations: An N x dim_obs array of observations. + actions: An N x dim_action array of actions. + next_observations: An N x dim_obs array of next observations. + rewards: An N-dim float array of rewards. + terminals: An N-dim boolean array of "done" or episode termination flags. + """ + if dataset is None: + dataset = env.get_dataset(**kwargs) + + N = dataset['rewards'].shape[0] + obs_ = [] + next_obs_ = [] + action_ = [] + reward_ = [] + done_ = [] + + # The newer version of the dataset adds an explicit + # timeouts field. Keep old method for backwards compatability. 
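+    # 'terminals' marks true environment terminations (e.g. falling over), while 'timeouts'
+    # marks episodes cut off by the time limit. Older datasets lack 'timeouts', so the code
+    # below falls back to counting steps against env._max_episode_steps.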
+ use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + for i in range(N-1): + obs = dataset['observations'][i].astype(np.float32) + new_obs = dataset['observations'][i+1].astype(np.float32) + action = dataset['actions'][i].astype(np.float32) + reward = dataset['rewards'][i].astype(np.float32) + done_bool = bool(dataset['terminals'][i]) + + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == env._max_episode_steps - 1) + if (not terminate_on_end) and final_timestep: + # Skip this transition and don't apply terminals on the last step of an episode + episode_step = 0 + continue + if done_bool or final_timestep: + episode_step = 0 + + obs_.append(obs) + next_obs_.append(new_obs) + action_.append(action) + reward_.append(reward) + done_.append(done_bool) + episode_step += 1 + + return { + 'observations': np.array(obs_), + 'actions': np.array(action_), + 'next_observations': np.array(next_obs_), + 'rewards': np.array(reward_), + 'terminals': np.array(done_), + } + + +def sequence_dataset(env, dataset=None, **kwargs): + """ + Returns an iterator through trajectories. + + Args: + env: An OfflineEnv object. + dataset: An optional dataset to pass in for processing. If None, + the dataset will default to env.get_dataset() + **kwargs: Arguments to pass to env.get_dataset(). + + Returns: + An iterator through dictionaries with keys: + observations + actions + rewards + terminals + """ + if dataset is None: + dataset = env.get_dataset(**kwargs) + + N = dataset['rewards'].shape[0] + data_ = collections.defaultdict(list) + + # The newer version of the dataset adds an explicit + # timeouts field. Keep old method for backwards compatability. + use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + for i in range(N): + done_bool = bool(dataset['terminals'][i]) + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == env._max_episode_steps - 1) + + for k in dataset: + data_[k].append(dataset[k][i]) + + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + yield episode_data + data_ = collections.defaultdict(list) + + episode_step += 1 + diff --git a/d4rl/d4rl/carla/__init__.py b/d4rl/d4rl/carla/__init__.py new file mode 100644 index 0000000..2f34cb0 --- /dev/null +++ b/d4rl/d4rl/carla/__init__.py @@ -0,0 +1,126 @@ +from .carla_env import CarlaObsDictEnv +from .carla_env import CarlaObsEnv +from gym.envs.registration import register + + +register( + id='carla-lane-v0', + entry_point='d4rl.carla:CarlaObsEnv', + max_episode_steps=250, + kwargs={ + 'ref_min_score': -0.8503839912088142, + 'ref_max_score': 1023.5784385429523, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5', + 'reward_type': 'lane_follow', + 'carla_args': dict( + vision_size=48, + vision_fov=48, + weather=False, + frame_skip=1, + steps=250, + multiagent=True, + lane=0, + lights=False, + record_dir="None", + ) + } +) + + +register( + id='carla-lane-render-v0', + entry_point='d4rl.carla:CarlaDictEnv', + max_episode_steps=250, + kwargs={ + 'ref_min_score': -0.8503839912088142, + 'ref_max_score': 1023.5784385429523, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow-v0.hdf5', + 'reward_type': 'lane_follow', + 'render_images': True, + 'carla_args': dict( + vision_size=48, + vision_fov=48, + weather=False, + 
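+            # Note: weather=False is cast to float in CarlaEnv, giving a
+            # changing_weather_speed of 0.0, i.e. static weather for this task.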
frame_skip=1, + steps=250, + multiagent=True, + lane=0, + lights=False, + record_dir="None", + ) + } +) + + +TOWN_STEPS = 1000 +register( + id='carla-town-v0', + entry_point='d4rl.carla:CarlaObsEnv', + max_episode_steps=TOWN_STEPS, + kwargs={ + 'ref_min_score': -114.81579500772153, # Average random returns + 'ref_max_score': 2440.1772022247314, # Average dataset returns + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5', + 'reward_type': 'goal_reaching', + 'carla_args': dict( + vision_size=48, + vision_fov=48, + weather=False, + frame_skip=1, + steps=TOWN_STEPS, + multiagent=True, + lane=0, + lights=False, + record_dir="None", + ) + } +) + + +register( + id='carla-town-full-v0', + entry_point='d4rl.carla:CarlaObsEnv', + max_episode_steps=TOWN_STEPS, + kwargs={ + 'ref_min_score': -114.81579500772153, # Average random returns + 'ref_max_score': 2440.1772022247314, # Average dataset returns + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', + 'reward_type': 'goal_reaching', + 'carla_args': dict( + vision_size=48, + vision_fov=48, + weather=False, + frame_skip=1, + steps=TOWN_STEPS, + multiagent=True, + lane=0, + lights=False, + record_dir="None", + ) + } +) + +register( + id='carla-town-render-v0', + entry_point='d4rl.carla:CarlaObsEnv', + max_episode_steps=TOWN_STEPS, + kwargs={ + 'ref_min_score': None, + 'ref_max_score': None, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', + 'render_images': True, + 'reward_type': 'goal_reaching', + 'carla_args': dict( + vision_size=48, + vision_fov=48, + weather=False, + frame_skip=1, + steps=TOWN_STEPS, + multiagent=True, + lane=0, + lights=False, + record_dir="None", + ) + } +) + diff --git a/d4rl/d4rl/carla/carla_env.py b/d4rl/d4rl/carla/carla_env.py new file mode 100644 index 0000000..9f60f68 --- /dev/null +++ b/d4rl/d4rl/carla/carla_env.py @@ -0,0 +1,1130 @@ +import argparse +import datetime +import glob +import os +import random +import sys +import time +from PIL import Image +from PIL.PngImagePlugin import PngInfo +import gym +from gym import Env +import gym.spaces as spaces + +#from . import proxy_env +from d4rl.offline_env import OfflineEnv + +try: + sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( + sys.version_info.major, + sys.version_info.minor, + 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) +except IndexError: + pass + +import carla +import math + +from dotmap import DotMap + +try: + import pygame +except ImportError: + raise RuntimeError('cannot import pygame, make sure pygame package is installed') + +try: + import numpy as np +except ImportError: + raise RuntimeError('cannot import numpy, make sure numpy package is installed') + +try: + import queue +except ImportError: + import Queue as queue + +# This is CARLA agent +from agents.navigation.agent import Agent, AgentState +from agents.navigation.local_planner import LocalPlanner +from agents.navigation.global_route_planner import GlobalRoutePlanner +from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO +from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle + +def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): + """ + Check if a target object is within a certain distance from a reference object. + A vehicle in front would be something around 0 deg, while one behind around 180 deg. 
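+
+    A target counts as "within distance" when it is closer than max_distance and the angle
+    between the reference orientation and the vector to the target lies strictly between
+    d_angle_th_low and d_angle_th_up; a target almost on top of the reference
+    (norm < 0.001) always passes.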
+ :param target_location: location of the target object + :param current_location: location of the reference object + :param orientation: orientation of the reference object + :param max_distance: maximum allowed distance + :param d_angle_th_up: upper thereshold for angle + :param d_angle_th_low: low thereshold for angle (optional, default is 0) + :return: True if target object is within max_distance ahead of the reference object + """ + target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) + norm_target = np.linalg.norm(target_vector) + + # If the vector is too short, we can simply stop here + if norm_target < 0.001: + return True + + if norm_target > max_distance: + return False + + forward_vector = np.array( + [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) + d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) + + return d_angle_th_low < d_angle < d_angle_th_up + +def compute_distance(location_1, location_2): + """ + Euclidean distance between 3D po-0.427844-0.427844ints + :param location_1, location_2: 3D points + """ + x = location_2.x - location_1.x + y = location_2.y - location_1.y + z = location_2.z - location_1.z + norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps + return norm + + +class CustomGlobalRoutePlanner(GlobalRoutePlanner): + def __init__(self, dao): + super(CustomGlobalRoutePlanner, self).__init__(dao=dao) + + def compute_direction_velocities(self, origin, velocity, destination): + node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) + + origin_xy = np.array([origin.x, origin.y]) + velocity_xy = np.array([velocity.x, velocity.y]) + first_node_xy = self._graph.nodes[node_list[0]]['vertex'] + first_node_xy = np.array([first_node_xy[0], first_node_xy[1]]) + target_direction_vector = first_node_xy - origin_xy + target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector) + + vel_s = np.dot(velocity_xy, target_unit_vector) + + unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8) + angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0)) + vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle) + return vel_s, vel_perp + + def compute_distance(self, origin, destination): + node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) + #print('Node list:', node_list) + first_node_xy = self._graph.nodes[node_list[1]]['vertex'] + #print('Diff:', origin, first_node_xy) + + #distance = 0.0 + distances = [] + distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy))) + + for idx in range(len(node_list) - 1): + distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) + #print('Distances:', distances) + #import pdb; pdb.set_trace() + return np.sum(distances) + + +class CarlaSyncMode(object): + """ + Context manager to synchronize output from different sensors. 
Synchronous + mode is enabled as long as we are inside this context + with CarlaSyncMode(world, sensors) as sync_mode: + while True: + data = sync_mode.tick(timeout=1.0) + """ + + def __init__(self, world, *sensors, **kwargs): + self.world = world + self.sensors = sensors + self.frame = None + self.delta_seconds = 1.0 / kwargs.get('fps', 20) + self._queues = [] + self._settings = None + + self.start() + + def start(self): + self._settings = self.world.get_settings() + self.frame = self.world.apply_settings(carla.WorldSettings( + no_rendering_mode=False, + synchronous_mode=True, + fixed_delta_seconds=self.delta_seconds)) + + def make_queue(register_event): + q = queue.Queue() + register_event(q.put) + self._queues.append(q) + + make_queue(self.world.on_tick) + for sensor in self.sensors: + make_queue(sensor.listen) + + def tick(self, timeout): + self.frame = self.world.tick() + data = [self._retrieve_data(q, timeout) for q in self._queues] + assert all(x.frame == self.frame for x in data) + return data + + def __exit__(self, *args, **kwargs): + self.world.apply_settings(self._settings) + + def _retrieve_data(self, sensor_queue, timeout): + while True: + data = sensor_queue.get(timeout=timeout) + if data.frame == self.frame: + return data + + +class Sun(object): + def __init__(self, azimuth, altitude): + self.azimuth = azimuth + self.altitude = altitude + self._t = 0.0 + + def tick(self, delta_seconds): + self._t += 0.008 * delta_seconds + self._t %= 2.0 * math.pi + self.azimuth += 0.25 * delta_seconds + self.azimuth %= 360.0 + min_alt, max_alt = [20, 90] + self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) + + def __str__(self): + return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) + + +class Storm(object): + def __init__(self, precipitation): + self._t = precipitation if precipitation > 0.0 else -50.0 + self._increasing = True + self.clouds = 0.0 + self.rain = 0.0 + self.wetness = 0.0 + self.puddles = 0.0 + self.wind = 0.0 + self.fog = 0.0 + + def tick(self, delta_seconds): + delta = (1.3 if self._increasing else -1.3) * delta_seconds + self._t = clamp(delta + self._t, -250.0, 100.0) + self.clouds = clamp(self._t + 40.0, 0.0, 90.0) + self.clouds = clamp(self._t + 40.0, 0.0, 60.0) + self.rain = clamp(self._t, 0.0, 80.0) + delay = -10.0 if self._increasing else 90.0 + self.puddles = clamp(self._t + delay, 0.0, 85.0) + self.wetness = clamp(self._t * 5, 0.0, 100.0) + self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 + self.fog = clamp(self._t - 10, 0.0, 30.0) + if self._t == -250.0: + self._increasing = True + if self._t == 100.0: + self._increasing = False + + def __str__(self): + return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) + + +class Weather(object): + def __init__(self, world, changing_weather_speed): + self.world = world + self.reset() + self.weather = world.get_weather() + self.changing_weather_speed = changing_weather_speed + self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) + self._storm = Storm(self.weather.precipitation) + + def reset(self): + weather_params = carla.WeatherParameters(sun_altitude_angle=90.) 
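+        # sun_altitude_angle=90 puts the sun directly overhead (midday); all other
+        # WeatherParameters fields are left at their defaults.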
+ self.world.set_weather(weather_params) + + def tick(self): + self._sun.tick(self.changing_weather_speed) + self._storm.tick(self.changing_weather_speed) + self.weather.cloudiness = self._storm.clouds + self.weather.precipitation = self._storm.rain + self.weather.precipitation_deposits = self._storm.puddles + self.weather.wind_intensity = self._storm.wind + self.weather.fog_density = self._storm.fog + self.weather.wetness = self._storm.wetness + self.weather.sun_azimuth_angle = self._sun.azimuth + self.weather.sun_altitude_angle = self._sun.altitude + self.world.set_weather(self.weather) + + def __str__(self): + return '%s %s' % (self._sun, self._storm) + +def clamp(value, minimum=0.0, maximum=100.0): + return max(minimum, min(value, maximum)) + +## Now the actual env +class CarlaEnv(object): + """ + CARLA agent, we will wrap this in a proxy env to get a gym env + """ + def __init__(self, render=False, carla_port=2000, record=False, record_dir=None, args=None, record_vision=False, reward_type='lane_follow', **kwargs): + self.render_display = render + self.record_display = record + print('[CarlaEnv] record_vision:', record_vision) + self.record_vision = record_vision + self.record_dir = record_dir + self.reward_type = reward_type + self.vision_size = args['vision_size'] + self.vision_fov = args['vision_fov'] + self.changing_weather_speed = float(args['weather']) + self.frame_skip = args['frame_skip'] + self.max_episode_steps = args['steps'] # DMC uses this + self.multiagent = args['multiagent'] + self.start_lane = args['lane'] + self.follow_traffic_lights = args['lights'] + if self.record_display: + assert self.render_display + + self.actor_list = [] + + if self.render_display: + pygame.init() + self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF) + self.font = get_font() + self.clock = pygame.time.Clock() + + self.client = carla.Client('localhost', carla_port) + self.client.set_timeout(2.0) + + self.world = self.client.get_world() + self.map = self.world.get_map() + + # tests specific to map 4: + if self.start_lane and self.map.name != "Town04": + raise NotImplementedError + + # remove old vehicles and sensors (in case they survived) + self.world.tick() + actor_list = self.world.get_actors() + for vehicle in actor_list.filter("*vehicle*"): + print("Warning: removing old vehicle") + vehicle.destroy() + for sensor in actor_list.filter("*sensor*"): + print("Warning: removing old sensor") + sensor.destroy() + + self.vehicle = None + self.vehicles_list = [] # their ids + self.reset_vehicle() # creates self.vehicle + self.actor_list.append(self.vehicle) + + blueprint_library = self.world.get_blueprint_library() + + if self.render_display: + self.camera_display = self.world.spawn_actor( + blueprint_library.find('sensor.camera.rgb'), + carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), + attach_to=self.vehicle) + self.actor_list.append(self.camera_display) + + bp = blueprint_library.find('sensor.camera.rgb') + bp.set_attribute('image_size_x', str(self.vision_size)) + bp.set_attribute('image_size_y', str(self.vision_size)) + bp.set_attribute('fov', str(self.vision_fov)) + location = carla.Location(x=1.6, z=1.7) + self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle) + self.actor_list.append(self.camera_vision) + + if self.record_display or self.record_vision: + if self.record_dir is None: + self.record_dir = "carla-{}-{}x{}-fov{}".format( + self.map.name.lower(), 
self.vision_size, self.vision_size, self.vision_fov) + if self.frame_skip > 1: + self.record_dir += '-{}'.format(self.frame_skip) + if self.changing_weather_speed > 0.0: + self.record_dir += '-weather' + if self.multiagent: + self.record_dir += '-mutiagent' + if self.follow_traffic_lights: + self.record_dir += '-lights' + self.record_dir += '-{}k'.format(self.max_episode_steps // 1000) + + now = datetime.datetime.now() + self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S") + os.mkdir(self.record_dir) + + if self.render_display: + self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20) + else: + self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20) + + # weather + self.weather = Weather(self.world, self.changing_weather_speed) + + # dummy variables, to match deep mind control's APIs + low = -1.0 + high = 1.0 + + self.action_space = spaces.Box(low=np.array((low, low)), high=np.array((high, high))) + + self.observation_space = DotMap() + self.observation_space.shape = (3, self.vision_size, self.vision_size) + self.observation_space.dtype = np.dtype(np.uint8) + self.reward_range = None + self.metadata = None + # self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32) + + self.horizon = self.max_episode_steps + self.image_shape = (3, self.vision_size, self.vision_size) + + # roaming carla agent + self.count = 0 + self.world.tick() + self.reset_init() + + self._proximity_threshold = 10.0 + self._traffic_light_threshold = 5.0 + self.actor_list = self.world.get_actors() + #for idx in range(len(self.actor_list)): + # print (idx, self.actor_list[idx]) + + # import ipdb; ipdb.set_trace() + self.vehicle_list = self.actor_list.filter("*vehicle*") + self.lights_list = self.actor_list.filter("*traffic_light*") + self.object_list = self.actor_list.filter("*traffic.*") + + # town nav + self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) + self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao) + self.route_planner.setup() + self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433) + + # roaming carla agent + # self.agent = None + # self.count = 0 + # self.world.tick() + self.reset() # creates self.agent + + + def reset_init(self): + self.reset_vehicle() + self.world.tick() + self.reset_other_vehicles() + self.world.tick() + + # + + self.count = 0 + + def reset(self): + #self.reset_vehicle() + #self.world.tick() + #self.reset_other_vehicles() + #self.world.tick() + #self.count = 0 + # get obs: + #for _ in range(5): + # self.world.tick() + #obs, _, _, _ = self.step() + + obs, _, done, _ = self.step() + + # keep resetting until vehicle is not collided + total_resets = 0 + while done: + self.reset_vehicle() + self.world.tick() + obs, _, done, _ = self.step() + total_resets += 1 + if total_resets > 10: + break + + return obs + + def reset_vehicle(self): + + if self.map.name == "Town04": + self.start_lane = -1 # np.random.choice([-1, -2, -3, -4]) # their positive values, not negative + start_x = 5. 
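+            # Town04 uses a fixed spawn on the highway (lane -1, x=5); other maps fall
+            # back to a random spawn point in the else branch below.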
+ vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90)) + else: + init_transforms = self.world.get_map().get_spawn_points() + vehicle_init_transform = random.choice(init_transforms) + #print('MyInitTransform', vehicle_init_transform) + + + if self.vehicle is None: # then create the ego vehicle + blueprint_library = self.world.get_blueprint_library() + vehicle_blueprint = blueprint_library.find('vehicle.audi.a2') + self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform) + + self.vehicle.set_transform(vehicle_init_transform) + self.vehicle.set_velocity(carla.Vector3D()) + self.vehicle.set_angular_velocity(carla.Vector3D()) + + def reset_other_vehicles(self): + if not self.multiagent: + return + + # clear out old vehicles + self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) + self.world.tick() + self.vehicles_list = [] + + traffic_manager = self.client.get_trafficmanager() + traffic_manager.set_global_distance_to_leading_vehicle(2.0) + traffic_manager.set_synchronous_mode(True) + blueprints = self.world.get_blueprint_library().filter('vehicle.*') + blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4] + + num_vehicles = 20 + if self.map.name == "Town04": + road_id = 47 + road_length = 117. + init_transforms = [] + for _ in range(num_vehicles): + lane_id = random.choice([-1, -2, -3, -4]) + vehicle_s = np.random.uniform(road_length) # length of road 47 + init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform) + else: + init_transforms = self.world.get_map().get_spawn_points() + init_transforms = np.random.choice(init_transforms, num_vehicles) + #print('OtherInitTransforms:') + #for transf in init_transforms: + # print(transf) + + # -------------- + # Spawn vehicles + # -------------- + batch = [] + for transform in init_transforms: + transform.location.z += 0.1 # otherwise can collide with the road it starts on + blueprint = random.choice(blueprints) + if blueprint.has_attribute('color'): + color = random.choice(blueprint.get_attribute('color').recommended_values) + blueprint.set_attribute('color', color) + if blueprint.has_attribute('driver_id'): + driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values) + blueprint.set_attribute('driver_id', driver_id) + blueprint.set_attribute('role_name', 'autopilot') + batch.append(carla.command.SpawnActor(blueprint, transform).then( + carla.command.SetAutopilot(carla.command.FutureActor, True))) + + for response in self.client.apply_batch_sync(batch, False): + self.vehicles_list.append(response.actor_id) + + for response in self.client.apply_batch_sync(batch): + if response.error: + pass + else: + self.vehicles_list.append(response.actor_id) + + traffic_manager.global_percentage_speed_difference(30.0) + + def step(self, action=None, traffic_light_color=""): + """ + rewards = [] + for _ in range(self.frame_skip): # default 1 + next_obs, reward, done, info = self._simulator_step(action, traffic_light_color) + rewards.append(reward) + if done: + break + return next_obs, np.mean(rewards), done, info + """ + return self._simulator_step(action, traffic_light_color) + + def _is_vehicle_hazard(self, vehicle, vehicle_list): + """ + :param vehicle_list: list of potential obstacle to check + :return: a tuple given by (bool_flag, vehicle), where + - bool_flag is True if there is a vehicle ahead blocking us + and False otherwise + - vehicle is the blocker object itself 
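+
+        Note: the implementation returns a 3-tuple (bool_flag, reward, vehicle), where
+        reward is -1.0 when a blocking vehicle is found ahead and 0.0 otherwise.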
+ """ + + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for target_vehicle in vehicle_list: + # do not account for the ego vehicle + if target_vehicle.id == vehicle.id: + continue + + # if the object is not in our lane it's not an obstacle + target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) + if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(target_vehicle.get_transform(), + vehicle.get_transform(), + self._proximity_threshold/10.0): + return (True, -1.0, target_vehicle) + + return (False, 0.0, None) + + def _is_object_hazard(self, vehicle, object_list): + """ + :param vehicle_list: list of potential obstacle to check + :return: a tuple given by (bool_flag, vehicle), where + - bool_flag is True if there is a vehicle ahead blocking us + and False otherwise + - vehicle is the blocker object itself + """ + + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for target_vehicle in object_list: + # do not account for the ego vehicle + if target_vehicle.id == vehicle.id: + continue + + # if the object is not in our lane it's not an obstacle + target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) + if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(target_vehicle.get_transform(), + vehicle.get_transform(), + self._proximity_threshold/40.0): + return (True, -1.0, target_vehicle) + + return (False, 0.0, None) + + def _is_light_red(self, vehicle): + """ + Method to check if there is a red light affecting us. This version of + the method is compatible with both European and US style traffic lights. 
+ :param lights_list: list containing TrafficLight objects + :return: a tuple given by (bool_flag, traffic_light), where + - bool_flag is True if there is a traffic light in RED + affecting us and False otherwise + - traffic_light is the object itself or None if there is no + red traffic light affecting us + """ + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for traffic_light in self.lights_list: + object_location = self._get_trafficlight_trigger_location(traffic_light) + object_waypoint = self.map.get_waypoint(object_location) + + if object_waypoint.road_id != ego_vehicle_waypoint.road_id: + continue + + ve_dir = ego_vehicle_waypoint.transform.get_forward_vector() + wp_dir = object_waypoint.transform.get_forward_vector() + dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z + + if dot_ve_wp < 0: + continue + + if is_within_distance_ahead(object_waypoint.transform, + vehicle.get_transform(), + self._traffic_light_threshold): + if traffic_light.state == carla.TrafficLightState.Red: + return (True, -0.1, traffic_light) + + return (False, 0.0, None) + + def _get_trafficlight_trigger_location(self, traffic_light): # pylint: disable=no-self-use + """ + Calculates the yaw of the waypoint that represents the trigger volume of the traffic light + """ + def rotate_point(point, radians): + """ + rotate a given point by a given angle + """ + rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y + rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y + + return carla.Vector3D(rotated_x, rotated_y, point.z) + + base_transform = traffic_light.get_transform() + base_rot = base_transform.rotation.yaw + area_loc = base_transform.transform(traffic_light.trigger_volume.location) + area_ext = traffic_light.trigger_volume.extent + + point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot)) + point_location = area_loc + carla.Location(x=point.x, y=point.y) + + return carla.Location(point_location.x, point_location.y, point_location.z) + + def _get_collision_reward(self, vehicle): + vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list) + + # Check the lane ids + loc = vehicle.get_location() + if loc is not None: + w = self.map.get_waypoint(loc) + if w is not None: + current_lane_id = w.lane_id + if current_lane_id not in [-1, 1]: + #print ('Lane: ', current_lane_id, self.start_lane) + vehicle_hazard = True + reward = -1.0 + else: + vehicle_hazard = True + reward = -1.0 + else: + vehicle_hazard = True + reward = -1.0 + + #print ('vehicle: ', loc, current_lane_id, self.start_lane) + return vehicle_hazard, reward + + def _get_traffic_light_reward(self, vehicle): + traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle) + return traffic_light_hazard, 0.0 + + def _get_object_collided_reward(self, vehicle): + object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list) + return object_hazard, reward + + def goal_reaching_reward(self, vehicle): + # Now we will write goal_reaching_rewards + vehicle_location = vehicle.get_location() + vehicle_velocity = vehicle.get_velocity() + + target_location = self.target_location + + # This is the distance computation + try: + dist = self.route_planner.compute_distance(vehicle_location, target_location) + vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location) + except TypeError: + # 
Weird bug where the graph disappears + vel_forward = 0 + vel_perp = 0 + + #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward))) + + #base_reward = -1.0 * (dist / 100.0) + 5.0 + base_reward = vel_forward + collided_done, collision_reward = self._get_collision_reward(vehicle) + traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) + object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) + total_reward = base_reward + 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward + reward_dict = dict() + reward_dict['collision'] = collision_reward + reward_dict['traffic_light'] = traffic_light_reward + reward_dict['object_collision'] = object_collided_reward + reward_dict['base_reward'] = base_reward + done_dict = dict() + done_dict['collided_done'] = collided_done + done_dict['traffic_light_done'] = traffic_light_done + done_dict['object_collided_done'] = object_collided_done + return total_reward, reward_dict, done_dict + + def lane_follow_reward(self, vehicle): + # assume on highway + vehicle_location = vehicle.get_location() + vehicle_waypoint = self.map.get_waypoint(vehicle_location) + vehicle_xy = np.array([vehicle_location.x, vehicle_location.y]) + vehicle_s = vehicle_waypoint.s + vehicle_velocity = vehicle.get_velocity() # Vector3D + vehicle_velocity_xy = np.array([vehicle_velocity.x, vehicle_velocity.y]) + # print ('Velocity: ', vehicle_velocity_xy) + speed = np.linalg.norm(vehicle_velocity_xy) + vehicle_waypoint_closest_to_road = \ + self.map.get_waypoint(vehicle_location, project_to_road=True, lane_type=carla.LaneType.Driving) + road_id = vehicle_waypoint_closest_to_road.road_id + assert road_id is not None + goal_abs_lane_id = 1 # just for goal-following + lane_id_sign = int(np.sign(vehicle_waypoint_closest_to_road.lane_id)) + assert lane_id_sign in [-1, 1] + goal_lane_id = goal_abs_lane_id * lane_id_sign + current_waypoint = self.map.get_waypoint(vehicle_location, project_to_road=False) + goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s) + + # Check for valid goal waypoint + if goal_waypoint is None: + print ('goal waypoint is None...') + # try to fix, bit of a hack, with CARLA waypoint discretizations + carla_waypoint_discretization = 0.02 # meters + goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s - carla_waypoint_discretization) + if goal_waypoint is None: + goal_waypoint = self.map.get_waypoint_xodr(road_id, goal_lane_id, vehicle_s + carla_waypoint_discretization) + + # set distance to 100 if the waypoint is off the road + if goal_waypoint is None: + print("Episode fail: goal waypoint is off the road! (frame %d)" % self.count) + done, dist, vel_s = True, 100., 0. + else: + goal_location = goal_waypoint.transform.location + goal_xy = np.array([goal_location.x, goal_location.y]) + # dist = np.linalg.norm(vehicle_xy - goal_xy) + dists = [] + for abs_lane_id in [1, 2, 3, 4]: + lane_id_ = abs_lane_id * lane_id_sign + wp = self.map.get_waypoint_xodr(road_id, lane_id_, vehicle_s) + if wp is not None: # lane 4 might not exist where the highway has a turnoff + loc = wp.transform.location + xy = np.array([loc.x, loc.y]) + dists.append(np.linalg.norm(vehicle_xy - xy)) + if dists: + dist = min(dists) # just try to get to the center of one of the lanes + else: + dist = 0. 
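+            # The base reward below favors the velocity component along the lane direction
+            # (vel_s) and penalizes the perpendicular component (vel_perp), both measured
+            # against the heading of the next waypoint ahead in the goal lane.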
+ next_goal_waypoint = goal_waypoint.next(0.1) # waypoints are ever 0.02 meters + if len(next_goal_waypoint) != 1: + print('warning: {} waypoints (not 1)'.format(len(next_goal_waypoint))) + if len(next_goal_waypoint) == 0: + print("Episode done: no more waypoints left. (frame %d)" % self.count) + done, vel_s, vel_perp = True, 0., 0. + else: + location_ahead = next_goal_waypoint[0].transform.location + highway_vector = np.array([location_ahead.x, location_ahead.y]) - goal_xy + highway_unit_vector = np.array(highway_vector) / np.linalg.norm(highway_vector) + vel_s = np.dot(vehicle_velocity_xy, highway_unit_vector) + + unit_velocity = vehicle_velocity_xy / (np.linalg.norm(vehicle_velocity_xy) + 1e-8) + angle = np.arccos(np.clip(np.dot(unit_velocity, highway_unit_vector), -1.0, 1.0)) + #vel_forward = np.linalg.norm(vehicle_velocity_xy) * np.cos(angle) + vel_perp = np.linalg.norm(vehicle_velocity_xy) * np.sin(angle) + #print('R:', np.clip(vel_s-5*vel_perp, -5.0, 5.0), 'vel_s:', vel_s, 'vel_perp:', vel_perp) + #import pdb; pdb.set_trace() + + done = False + + # not algorithm's fault, but the simulator sometimes throws the car in the air wierdly + # usually in initial few frames, which can be ignored + """ + if vehicle_velocity.z > 1. and self.count < 20: + print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_velocity.z, self.count)) + done = True + if vehicle_location.z > 0.5 and self.count < 20: + print("Episode done: vertical velocity too high ({}), usually a simulator glitch (frame {})".format(vehicle_location.z, self.count)) + done = True + """ + + ## Add rewards for collision and optionally traffic lights + vehicle_location = vehicle.get_location() + base_reward = np.clip(vel_s - 5*vel_perp, -5.0, 5.0) + collided_done, collision_reward = self._get_collision_reward(vehicle) + traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) + object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) + total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward + reward_dict = dict() + reward_dict['collision'] = collision_reward + reward_dict['traffic_light'] = traffic_light_reward + reward_dict['object_collision'] = object_collided_reward + reward_dict['base_reward'] = base_reward + reward_dict['base_reward_vel_s'] = vel_s + reward_dict['base_reward_vel_perp'] = vel_perp + done_dict = dict() + done_dict['collided_done'] = collided_done + done_dict['traffic_light_done'] = traffic_light_done + done_dict['object_collided_done'] = object_collided_done + done_dict['base_done'] = done + return total_reward, reward_dict, done_dict + + def _simulator_step(self, action, traffic_light_color): + + if action is None: + throttle, steer, brake = 0., 0., 0. + else: + steer = float(action[1]) + throttle_brake = float(action[0]) + + if throttle_brake >= 0.0: + throttle = throttle_brake + brake = 0.0 + else: + throttle = 0.0 + brake = -throttle_brake + + vehicle_control = carla.VehicleControl( + throttle=float(throttle), + steer=float(steer), + brake=float(brake), + hand_brake=False, + reverse=False, + manual_gear_shift=False + ) + self.vehicle.apply_control(vehicle_control) + + # Advance the simulation and wait for the data. 
+ if self.render_display: + snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0) + else: + snapshot, vision_image = self.sync_mode.tick(timeout=2.0) + + # Weather evolves + self.weather.tick() + + # Draw the display. + if self.render_display: + self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10)) + self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28)) + self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46)) + self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64)) + pygame.display.flip() + + # Format rl image + bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4) # BGRA format + bgr = bgra[:, :, :3] # BGR format (84 x 84 x 3) + rgb = np.flip(bgr, axis=2) # RGB format (84 x 84 x 3) + + if self.render_display and self.record_display: + image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count) + pygame.image.save(self.render_display, image_name) + # # Can animate with: + # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4 + if self.record_vision: + image_name = os.path.join(self.record_dir, "vision%08d.png" % self.count) + print('savedimg:', image_name) + im = Image.fromarray(rgb) + + # add any meta data you like into the image before we save it: + metadata = PngInfo() + metadata.add_text("throttle", str(throttle)) + metadata.add_text("steer", str(steer)) + metadata.add_text("brake", str(brake)) + metadata.add_text("lights", traffic_light_color) + + # acceleration + acceleration = self.vehicle.get_acceleration() + metadata.add_text("acceleration_x", str(acceleration.x)) + metadata.add_text("acceleration_y", str(acceleration.y)) + metadata.add_text("acceleration_z", str(acceleration.z)) + # angular velocity + angular_velocity = self.vehicle.get_angular_velocity() + metadata.add_text("angular_velocity_x", str(angular_velocity.x)) + metadata.add_text("angular_velocity_y", str(angular_velocity.y)) + metadata.add_text("angular_velocity_z", str(angular_velocity.z)) + # location + location = self.vehicle.get_location() + metadata.add_text("location_x", str(location.x)) + metadata.add_text("location_y", str(location.y)) + metadata.add_text("location_z", str(location.z)) + # rotation + rotation = self.vehicle.get_transform().rotation + metadata.add_text("rotation_pitch", str(rotation.pitch)) + metadata.add_text("rotation_yaw", str(rotation.yaw)) + metadata.add_text("rotation_roll", str(rotation.roll)) + forward_vector = rotation.get_forward_vector() + metadata.add_text("forward_vector_x", str(forward_vector.x)) + metadata.add_text("forward_vector_y", str(forward_vector.y)) + metadata.add_text("forward_vector_z", str(forward_vector.z)) + # velocity + velocity = self.vehicle.get_velocity() + metadata.add_text("velocity_x", str(velocity.x)) + metadata.add_text("velocity_y", str(velocity.y)) + metadata.add_text("velocity_z", str(velocity.z)) + # weather + metadata.add_text("weather_cloudiness ", str(self.weather.weather.cloudiness)) + metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation)) + metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits)) + metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity)) + metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density)) 
+ metadata.add_text("weather_wetness", str(self.weather.weather.wetness)) + metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle)) + # settings + metadata.add_text("settings_map", self.map.name) + metadata.add_text("settings_vision_size", str(self.vision_size)) + metadata.add_text("settings_vision_fov", str(self.vision_fov)) + metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed)) + metadata.add_text("settings_multiagent", str(self.multiagent)) + # traffic lights + metadata.add_text("traffic_lights_color", "UNLABELED") + metadata.add_text("reward", str(reward)) + + ## Add in reward dict + for key in reward_dict: + metadata.add_text("reward_" + str(key), str(reward_dict[key])) + + for key in done_dict: + metadata.add_text("done_" + str(key), str(done_dict[key])) + + ## Save the target location as well + metadata.add_text('target_location_x', str(self.target_location.x)) + metadata.add_text('target_location_y', str(self.target_location.y)) + metadata.add_text('target_location_z', str(self.target_location.z)) + + im.save(image_name, "PNG", pnginfo=metadata) + + self.count += 1 + + next_obs = rgb + + done = False + if done: + print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps)) + + if self.reward_type=='lane_follow': + reward, reward_dict, done_dict = self.lane_follow_reward(self.vehicle) + elif self.reward_type=='goal_reaching': + reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle) + else: + raise ValueError('unknown reward type:', self.reward_type) + + info = reward_dict + info.update(done_dict) + done = False + for key in done_dict: + done = (done or done_dict[key]) + #if done: + # print('done_dict:', done_dict, 'r:', reward) + return next_obs, reward, done, info + + def finish(self): + print('destroying actors.') + for actor in self.actor_list: + actor.destroy() + print('\ndestroying %d vehicles' % len(self.vehicles_list)) + self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) + time.sleep(0.5) + pygame.quit() + print('done.') + + +class CarlaObsDictEnv(OfflineEnv): + def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs): + self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images) + print('[CarlaObsDictEnv] render_images:', render_images) + self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, record_vision=render_images) + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + + self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape)) + + self.observation_space = spaces.Dict({ + 'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) + }) + print (self.observation_space) + super(CarlaObsDictEnv, self).__init__(**kwargs) + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + self._wrapped_env.reset_init() + obs = (self._wrapped_env.reset(**kwargs)) + obs_dict = dict() + # Also normalize obs + obs_dict['image'] = (obs.astype(np.float32) / 255.0).flatten() + return obs_dict + + def step(self, action): + #print ('Action: ', action) + next_obs, reward, done, info = self._wrapped_env.step(action) + next_obs_dict = dict() + next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten() + # print ('Reward: ', 
reward) + # print ('Done dict: ', info) + return next_obs_dict, reward, done, info + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self._wrapped_env.terminate() + + def __getattr__(self, attr): + if attr == '_wrapped_env': + raise AttributeError() + return getattr(self._wrapped_env, attr) + + def __getstate__(self): + """ + This is useful to override in case the wrapped env has some funky + __getstate__ that doesn't play well with overriding __getattr__. + + The main problematic case is/was gym's EzPickle serialization scheme. + :return: + """ + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return '{}({})'.format(type(self).__name__, self.wrapped_env) + + +class CarlaObsEnv(OfflineEnv): + def __init__(self, carla_args=None, carla_port=2000, reward_type='lane_follow', render_images=False, **kwargs): + self._wrapped_env = CarlaEnv(carla_port=carla_port, args=carla_args, reward_type=reward_type, record_vision=render_images) + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + self.observation_size = int(np.prod(self._wrapped_env.observation_space.shape)) + self.observation_space = spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) + #self.observation_space = spaces.Dict({ + # 'image':spaces.Box(low=np.array([0.0] * self.observation_size), high=np.array([256.0,] * self.observation_size)) + #}) + super(CarlaObsEnv, self).__init__(**kwargs) + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + self._wrapped_env.reset_init() + obs = (self._wrapped_env.reset(**kwargs)) + obs_dict = dict() + # Also normalize obs + obs_dict = (obs.astype(np.float32) / 255.0).flatten() + return obs_dict + + def step(self, action): + #print ('Action: ', action) + next_obs, reward, done, info = self._wrapped_env.step(action) + #next_obs_dict = dict() + #next_obs_dict['image'] = (next_obs.astype(np.float32) / 255.0).flatten() + next_obs_dict = (next_obs.astype(np.float32) / 255.0).flatten() + # print ('Reward: ', reward) + # print ('Done dict: ', info) + return next_obs_dict, reward, done, info + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self._wrapped_env.terminate() + + def __getattr__(self, attr): + if attr == '_wrapped_env': + raise AttributeError() + return getattr(self._wrapped_env, attr) + + def __getstate__(self): + """ + This is useful to override in case the wrapped env has some funky + __getstate__ that doesn't play well with overriding __getattr__. + + The main problematic case is/was gym's EzPickle serialization scheme. 
+ :return: + """ + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return '{}({})'.format(type(self).__name__, self.wrapped_env) + +if __name__ == '__main__': + variant = dict() + variant['vision_size'] = 48 + variant['vision_fov'] = 48 + variant['weather'] = False + variant['frame_skip'] = 1 + variant['steps'] = 100000 + variant['multiagent'] = False + variant['lane'] = 0 + variant['lights'] = False + variant['record_dir'] = None + + env = CarlaEnv(args=variant) + carla_gym_env = proxy_env.ProxyEnv(env) diff --git a/d4rl/d4rl/carla/data_collection_agent_lane.py b/d4rl/d4rl/carla/data_collection_agent_lane.py new file mode 100644 index 0000000..2187970 --- /dev/null +++ b/d4rl/d4rl/carla/data_collection_agent_lane.py @@ -0,0 +1,461 @@ +# !/usr/bin/env python + +# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de +# Barcelona (UAB). +# +# This work is licensed under the terms of the MIT license. +# For a copy, see . +# +# Modified by Rowan McAllister on 20 April 2020 + +import argparse +import datetime +import glob +import os +import random +import sys +import time +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +try: + sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( + sys.version_info.major, + sys.version_info.minor, + 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) +except IndexError: + pass + +import carla +import math + +from dotmap import DotMap + +try: + import pygame +except ImportError: + raise RuntimeError('cannot import pygame, make sure pygame package is installed') + +try: + import numpy as np +except ImportError: + raise RuntimeError('cannot import numpy, make sure numpy package is installed') + +try: + import queue +except ImportError: + import Queue as queue + +from agents.navigation.agent import Agent, AgentState +from agents.navigation.local_planner import LocalPlanner +from agents.navigation.global_route_planner import GlobalRoutePlanner +from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle +from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO + + +def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): + """ + Check if a target object is within a certain distance from a reference object. + A vehicle in front would be something around 0 deg, while one behind around 180 deg. 
+ :param target_location: location of the target object + :param current_location: location of the reference object + :param orientation: orientation of the reference object + :param max_distance: maximum allowed distance + :param d_angle_th_up: upper thereshold for angle + :param d_angle_th_low: low thereshold for angle (optional, default is 0) + :return: True if target object is within max_distance ahead of the reference object + """ + target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) + norm_target = np.linalg.norm(target_vector) + + # If the vector is too short, we can simply stop here + if norm_target < 0.001: + return True + + if norm_target > max_distance: + return False + + forward_vector = np.array( + [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) + d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) + + return d_angle_th_low < d_angle < d_angle_th_up + + +def compute_distance(location_1, location_2): + """ + Euclidean distance between 3D points + :param location_1, location_2: 3D points + """ + x = location_2.x - location_1.x + y = location_2.y - location_1.y + z = location_2.z - location_1.z + norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps + return norm + + + +class CarlaSyncMode(object): + """ + Context manager to synchronize output from different sensors. Synchronous + mode is enabled as long as we are inside this context + + with CarlaSyncMode(world, sensors) as sync_mode: + while True: + data = sync_mode.tick(timeout=1.0) + + """ + + def __init__(self, world, *sensors, **kwargs): + self.world = world + self.sensors = sensors + self.frame = None + self.delta_seconds = 1.0 / kwargs.get('fps', 20) + self._queues = [] + self._settings = None + + self.start() + + def start(self): + self._settings = self.world.get_settings() + self.frame = self.world.apply_settings(carla.WorldSettings( + no_rendering_mode=False, + synchronous_mode=True, + fixed_delta_seconds=self.delta_seconds)) + + def make_queue(register_event): + q = queue.Queue() + register_event(q.put) + self._queues.append(q) + + make_queue(self.world.on_tick) + for sensor in self.sensors: + make_queue(sensor.listen) + + def tick(self, timeout): + self.frame = self.world.tick() + data = [self._retrieve_data(q, timeout) for q in self._queues] + assert all(x.frame == self.frame for x in data) + return data + + def __exit__(self, *args, **kwargs): + self.world.apply_settings(self._settings) + + def _retrieve_data(self, sensor_queue, timeout): + while True: + data = sensor_queue.get(timeout=timeout) + if data.frame == self.frame: + return data + + +def draw_image(surface, image, blend=False): + array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8")) + array = np.reshape(array, (image.height, image.width, 4)) + array = array[:, :, :3] + array = array[:, :, ::-1] + image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1)) + if blend: + image_surface.set_alpha(100) + surface.blit(image_surface, (0, 0)) + + +def get_font(): + fonts = [x for x in pygame.font.get_fonts()] + default_font = 'ubuntumono' + font = default_font if default_font in fonts else fonts[0] + font = pygame.font.match_font(font) + return pygame.font.Font(font, 14) + + +def should_quit(): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return True + elif event.type == pygame.KEYUP: + if event.key == pygame.K_ESCAPE: + return True + return False + + +def clamp(value, minimum=0.0, 
maximum=100.0): + return max(minimum, min(value, maximum)) + + +class Sun(object): + def __init__(self, azimuth, altitude): + self.azimuth = azimuth + self.altitude = altitude + self._t = 0.0 + + def tick(self, delta_seconds): + self._t += 0.008 * delta_seconds + self._t %= 2.0 * math.pi + self.azimuth += 0.25 * delta_seconds + self.azimuth %= 360.0 + min_alt, max_alt = [20, 90] + self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) + + def __str__(self): + return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) + + +class Storm(object): + def __init__(self, precipitation): + self._t = precipitation if precipitation > 0.0 else -50.0 + self._increasing = True + self.clouds = 0.0 + self.rain = 0.0 + self.wetness = 0.0 + self.puddles = 0.0 + self.wind = 0.0 + self.fog = 0.0 + + def tick(self, delta_seconds): + delta = (1.3 if self._increasing else -1.3) * delta_seconds + self._t = clamp(delta + self._t, -250.0, 100.0) + self.clouds = clamp(self._t + 40.0, 0.0, 90.0) + self.clouds = clamp(self._t + 40.0, 0.0, 60.0) + self.rain = clamp(self._t, 0.0, 80.0) + delay = -10.0 if self._increasing else 90.0 + self.puddles = clamp(self._t + delay, 0.0, 85.0) + self.wetness = clamp(self._t * 5, 0.0, 100.0) + self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 + self.fog = clamp(self._t - 10, 0.0, 30.0) + if self._t == -250.0: + self._increasing = True + if self._t == 100.0: + self._increasing = False + + def __str__(self): + return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) + + +class Weather(object): + def __init__(self, world, changing_weather_speed): + self.world = world + self.reset() + self.weather = world.get_weather() + self.changing_weather_speed = changing_weather_speed + self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) + self._storm = Storm(self.weather.precipitation) + + def reset(self): + weather_params = carla.WeatherParameters(sun_altitude_angle=90.) 
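+ # reset() applies fresh WeatherParameters with the sun at zenith (altitude 90 deg); Sun and Storm then evolve the weather from this baseline on every tick().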
+ self.world.set_weather(weather_params) + + def tick(self): + self._sun.tick(self.changing_weather_speed) + self._storm.tick(self.changing_weather_speed) + self.weather.cloudiness = self._storm.clouds + self.weather.precipitation = self._storm.rain + self.weather.precipitation_deposits = self._storm.puddles + self.weather.wind_intensity = self._storm.wind + self.weather.fog_density = self._storm.fog + self.weather.wetness = self._storm.wetness + self.weather.sun_azimuth_angle = self._sun.azimuth + self.weather.sun_altitude_angle = self._sun.altitude + self.world.set_weather(self.weather) + + def __str__(self): + return '%s %s' % (self._sun, self._storm) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--vision_size', type=int, default=84) + parser.add_argument('--vision_fov', type=int, default=90) + parser.add_argument('--weather', default=False, action='store_true') + parser.add_argument('--frame_skip', type=int, default=1), + parser.add_argument('--steps', type=int, default=100000) + parser.add_argument('--multiagent', default=False, action='store_true'), + parser.add_argument('--lane', type=int, default=0) + parser.add_argument('--lights', default=False, action='store_true') + args = parser.parse_args() + return args + + +class LocalPlannerModified(LocalPlanner): + + def __del__(self): + pass # otherwise it deletes our vehicle object + + def run_step(self): + return super().run_step(debug=False) # otherwise by default shows waypoints, that interfere with our camera + + +class RoamingAgent(Agent): + """ + RoamingAgent implements a basic agent that navigates scenes making random + choices when facing an intersection. + + This agent respects traffic lights and other vehicles. + + NOTE: need to re-create after each env reset + """ + + def __init__(self, env): + """ + + :param vehicle: actor to apply to local planner logic onto + """ + vehicle = env.vehicle + follow_traffic_lights = env.follow_traffic_lights + super(RoamingAgent, self).__init__(vehicle) + self._proximity_threshold = 10.0 # meters + self._state = AgentState.NAVIGATING + self._local_planner = LocalPlannerModified(self._vehicle) + self._follow_traffic_lights = follow_traffic_lights + + def compute_action(self): + action, traffic_light = self.run_step() + throttle = action.throttle + brake = action.brake + steer = action.steer + #print('tbsl:', throttle, brake, steer, traffic_light) + if brake == 0.0: + return np.array([throttle, steer]) + else: + return np.array([-brake, steer]) + + def run_step(self): + """ + Execute one step of navigation. + :return: carla.VehicleControl + """ + + # is there an obstacle in front of us? 
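+ # Two hazard sources are checked below: a vehicle ahead in the same lane and, when traffic lights are followed, a red light; either one replaces the local planner's control with emergency_stop().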
+ hazard_detected = False + + # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles + actor_list = self._world.get_actors() + vehicle_list = actor_list.filter("*vehicle*") + lights_list = actor_list.filter("*traffic_light*") + + # check possible obstacles + vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) + if vehicle_state: + + self._state = AgentState.BLOCKED_BY_VEHICLE + hazard_detected = True + + # check for the state of the traffic lights + traffic_light_color = self._is_light_red(lights_list) + if traffic_light_color == 'RED' and self._follow_traffic_lights: + self._state = AgentState.BLOCKED_RED_LIGHT + hazard_detected = True + + if hazard_detected: + control = self.emergency_stop() + else: + self._state = AgentState.NAVIGATING + # standard local planner behavior + control = self._local_planner.run_step() + + #print ('Action chosen: ', control) + return control, traffic_light_color + + # override case class + def _is_light_red_europe_style(self, lights_list): + """ + This method is specialized to check European style traffic lights. + Only suitable for Towns 03 -- 07. + """ + ego_vehicle_location = self._vehicle.get_location() + ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) + + traffic_light_color = "NONE" # default, if no traffic lights are seen + + for traffic_light in lights_list: + object_waypoint = self._map.get_waypoint(traffic_light.get_location()) + if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + object_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(traffic_light.get_transform(), + self._vehicle.get_transform(), + self._proximity_threshold): + if traffic_light.state == carla.TrafficLightState.Red: + return "RED" + elif traffic_light.state == carla.TrafficLightState.Yellow: + traffic_light_color = "YELLOW" + elif traffic_light.state == carla.TrafficLightState.Green: + if traffic_light_color is not "YELLOW": # (more severe) + traffic_light_color = "GREEN" + else: + import pdb; pdb.set_trace() + # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate + + return traffic_light_color + + # override case class + def _is_light_red_us_style(self, lights_list, debug=False): + ego_vehicle_location = self._vehicle.get_location() + ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) + + traffic_light_color = "NONE" # default, if no traffic lights are seen + + if ego_vehicle_waypoint.is_junction: + # It is too late. Do not block the intersection! Keep going! 
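+ # Returning the special "JUNCTION" label instead of a light color keeps the caller from braking while already inside the intersection.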
+ return "JUNCTION" + + if self._local_planner.target_waypoint is not None: + if self._local_planner.target_waypoint.is_junction: + min_angle = 180.0 + sel_magnitude = 0.0 + sel_traffic_light = None + for traffic_light in lights_list: + loc = traffic_light.get_location() + magnitude, angle = compute_magnitude_angle(loc, + ego_vehicle_location, + self._vehicle.get_transform().rotation.yaw) + if magnitude < 60.0 and angle < min(25.0, min_angle): + sel_magnitude = magnitude + sel_traffic_light = traffic_light + min_angle = angle + + if sel_traffic_light is not None: + if debug: + print('=== Magnitude = {} | Angle = {} | ID = {}'.format( + sel_magnitude, min_angle, sel_traffic_light.id)) + + if self._last_traffic_light is None: + self._last_traffic_light = sel_traffic_light + + if self._last_traffic_light.state == carla.TrafficLightState.Red: + return "RED" + elif self._last_traffic_light.state == carla.TrafficLightState.Yellow: + traffic_light_color = "YELLOW" + elif self._last_traffic_light.state == carla.TrafficLightState.Green: + if traffic_light_color is not "YELLOW": # (more severe) + traffic_light_color = "GREEN" + else: + import pdb; pdb.set_trace() + # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate + else: + self._last_traffic_light = None + + return traffic_light_color + + +if __name__ == '__main__': + + # example call: + # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05 + # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights + + args = parse_args() + env = CarlaEnv(args) + + try: + done = False + while not done: + action, traffic_light_color = env.compute_action() + next_obs, reward, done, info = env.step(action, traffic_light_color) + print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location()) + if done: + # env.reset_init() + # env.reset() + done = False + + finally: + env.finish() diff --git a/d4rl/d4rl/carla/data_collection_town.py b/d4rl/d4rl/carla/data_collection_town.py new file mode 100644 index 0000000..e64287f --- /dev/null +++ b/d4rl/d4rl/carla/data_collection_town.py @@ -0,0 +1,1084 @@ +#!/usr/bin/env python + +# Copyright (c) 2019 Computer Vision Center (CVC) at the Universitat Autonoma de +# Barcelona (UAB). +# +# This work is licensed under the terms of the MIT license. +# For a copy, see . 
+# +# Modified by Rowan McAllister on 20 April 2020 + +import argparse +import datetime +import glob +import os +import random +import sys +import time +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +try: + sys.path.append(glob.glob('../carla/dist/carla-*%d.%d-%s.egg' % ( + sys.version_info.major, + sys.version_info.minor, + 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) +except IndexError: + pass + +import carla +import math + +from dotmap import DotMap + +try: + import pygame +except ImportError: + raise RuntimeError('cannot import pygame, make sure pygame package is installed') + +try: + import numpy as np +except ImportError: + raise RuntimeError('cannot import numpy, make sure numpy package is installed') + +try: + import queue +except ImportError: + import Queue as queue + +from agents.navigation.agent import Agent, AgentState +from agents.navigation.local_planner import LocalPlanner +from agents.navigation.global_route_planner import GlobalRoutePlanner +from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO +from agents.tools.misc import is_within_distance_ahead #, is_within_distance, compute_distance +from agents.tools.misc import is_within_distance_ahead, compute_magnitude_angle + +def is_within_distance(target_location, current_location, orientation, max_distance, d_angle_th_up, d_angle_th_low=0): + """ + Check if a target object is within a certain distance from a reference object. + A vehicle in front would be something around 0 deg, while one behind around 180 deg. + :param target_location: location of the target object + :param current_location: location of the reference object + :param orientation: orientation of the reference object + :param max_distance: maximum allowed distance + :param d_angle_th_up: upper thereshold for angle + :param d_angle_th_low: low thereshold for angle (optional, default is 0) + :return: True if target object is within max_distance ahead of the reference object + """ + target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) + norm_target = np.linalg.norm(target_vector) + + # If the vector is too short, we can simply stop here + if norm_target < 0.001: + return True + + if norm_target > max_distance: + return False + + forward_vector = np.array( + [math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) + d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1., 1.))) + + return d_angle_th_low < d_angle < d_angle_th_up + +def compute_distance(location_1, location_2): + """ + Euclidean distance between 3D points + :param location_1, location_2: 3D points + """ + x = location_2.x - location_1.x + y = location_2.y - location_1.y + z = location_2.z - location_1.z + norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps + return norm + + +class CustomGlobalRoutePlanner(GlobalRoutePlanner): + def __init__(self, dao): + super(CustomGlobalRoutePlanner, self).__init__(dao=dao) + + """ + def compute_distance(self, origin, destination): + node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) + distance = 0.0 + for idx in range(len(node_list) - 1): + distance += (super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) + # print ('Distance: ', distance) + return distance + """ + + def compute_direction_velocities(self, origin, velocity, destination): + node_list = super(CustomGlobalRoutePlanner, 
self)._path_search(origin=origin, destination=destination) + + origin_xy = np.array([origin.x, origin.y]) + velocity_xy = np.array([velocity.x, velocity.y]) + + first_node_xy = self._graph.nodes[node_list[1]]['vertex'] + first_node_xy = np.array([first_node_xy[0], first_node_xy[1]]) + target_direction_vector = first_node_xy - origin_xy + target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector) + + vel_s = np.dot(velocity_xy, target_unit_vector) + + unit_velocity = velocity_xy / (np.linalg.norm(velocity_xy) + 1e-8) + angle = np.arccos(np.clip(np.dot(unit_velocity, target_unit_vector), -1.0, 1.0)) + vel_perp = np.linalg.norm(velocity_xy) * np.sin(angle) + return vel_s, vel_perp + + def compute_distance(self, origin, destination): + node_list = super(CustomGlobalRoutePlanner, self)._path_search(origin=origin, destination=destination) + #print('Node list:', node_list) + first_node_xy = self._graph.nodes[node_list[0]]['vertex'] + #print('Diff:', origin, first_node_xy) + + #distance = 0.0 + distances = [] + distances.append(np.linalg.norm(np.array([origin.x, origin.y, 0.0]) - np.array(first_node_xy))) + + for idx in range(len(node_list) - 1): + distances.append(super(CustomGlobalRoutePlanner, self)._distance_heuristic(node_list[idx], node_list[idx+1])) + #print('Distances:', distances) + #import pdb; pdb.set_trace() + return np.sum(distances) + +class CarlaSyncMode(object): + """ + Context manager to synchronize output from different sensors. Synchronous + mode is enabled as long as we are inside this context + + with CarlaSyncMode(world, sensors) as sync_mode: + while True: + data = sync_mode.tick(timeout=1.0) + + """ + + def __init__(self, world, *sensors, **kwargs): + self.world = world + self.sensors = sensors + self.frame = None + self.delta_seconds = 1.0 / kwargs.get('fps', 20) + self._queues = [] + self._settings = None + + self.start() + + def start(self): + self._settings = self.world.get_settings() + self.frame = self.world.apply_settings(carla.WorldSettings( + no_rendering_mode=False, + synchronous_mode=True, + fixed_delta_seconds=self.delta_seconds)) + + def make_queue(register_event): + q = queue.Queue() + register_event(q.put) + self._queues.append(q) + + make_queue(self.world.on_tick) + for sensor in self.sensors: + make_queue(sensor.listen) + + def tick(self, timeout): + self.frame = self.world.tick() + data = [self._retrieve_data(q, timeout) for q in self._queues] + assert all(x.frame == self.frame for x in data) + return data + + def __exit__(self, *args, **kwargs): + self.world.apply_settings(self._settings) + + def _retrieve_data(self, sensor_queue, timeout): + while True: + data = sensor_queue.get(timeout=timeout) + if data.frame == self.frame: + return data + + +def draw_image(surface, image, blend=False): + array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8")) + array = np.reshape(array, (image.height, image.width, 4)) + array = array[:, :, :3] + array = array[:, :, ::-1] + image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1)) + if blend: + image_surface.set_alpha(100) + surface.blit(image_surface, (0, 0)) + + +def get_font(): + fonts = [x for x in pygame.font.get_fonts()] + default_font = 'ubuntumono' + font = default_font if default_font in fonts else fonts[0] + font = pygame.font.match_font(font) + return pygame.font.Font(font, 14) + + +def should_quit(): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return True + elif event.type == pygame.KEYUP: + if event.key == pygame.K_ESCAPE: + 
return True + return False + + +def clamp(value, minimum=0.0, maximum=100.0): + return max(minimum, min(value, maximum)) + + +class Sun(object): + def __init__(self, azimuth, altitude): + self.azimuth = azimuth + self.altitude = altitude + self._t = 0.0 + + def tick(self, delta_seconds): + self._t += 0.008 * delta_seconds + self._t %= 2.0 * math.pi + self.azimuth += 0.25 * delta_seconds + self.azimuth %= 360.0 + min_alt, max_alt = [20, 90] + self.altitude = 0.5 * (max_alt + min_alt) + 0.5 * (max_alt - min_alt) * math.cos(self._t) + + def __str__(self): + return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth) + + +class Storm(object): + def __init__(self, precipitation): + self._t = precipitation if precipitation > 0.0 else -50.0 + self._increasing = True + self.clouds = 0.0 + self.rain = 0.0 + self.wetness = 0.0 + self.puddles = 0.0 + self.wind = 0.0 + self.fog = 0.0 + + def tick(self, delta_seconds): + delta = (1.3 if self._increasing else -1.3) * delta_seconds + self._t = clamp(delta + self._t, -250.0, 100.0) + self.clouds = clamp(self._t + 40.0, 0.0, 90.0) + self.clouds = clamp(self._t + 40.0, 0.0, 60.0) + self.rain = clamp(self._t, 0.0, 80.0) + delay = -10.0 if self._increasing else 90.0 + self.puddles = clamp(self._t + delay, 0.0, 85.0) + self.wetness = clamp(self._t * 5, 0.0, 100.0) + self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40 + self.fog = clamp(self._t - 10, 0.0, 30.0) + if self._t == -250.0: + self._increasing = True + if self._t == 100.0: + self._increasing = False + + def __str__(self): + return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) + + +class Weather(object): + def __init__(self, world, changing_weather_speed): + self.world = world + self.reset() + self.weather = world.get_weather() + self.changing_weather_speed = changing_weather_speed + self._sun = Sun(self.weather.sun_azimuth_angle, self.weather.sun_altitude_angle) + self._storm = Storm(self.weather.precipitation) + + def reset(self): + weather_params = carla.WeatherParameters(sun_altitude_angle=90.) 
+ self.world.set_weather(weather_params) + + def tick(self): + self._sun.tick(self.changing_weather_speed) + self._storm.tick(self.changing_weather_speed) + self.weather.cloudiness = self._storm.clouds + self.weather.precipitation = self._storm.rain + self.weather.precipitation_deposits = self._storm.puddles + self.weather.wind_intensity = self._storm.wind + self.weather.fog_density = self._storm.fog + self.weather.wetness = self._storm.wetness + self.weather.sun_azimuth_angle = self._sun.azimuth + self.weather.sun_altitude_angle = self._sun.altitude + self.world.set_weather(self.weather) + + def __str__(self): + return '%s %s' % (self._sun, self._storm) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--vision_size', type=int, default=84) + parser.add_argument('--vision_fov', type=int, default=90) + parser.add_argument('--weather', default=False, action='store_true') + parser.add_argument('--frame_skip', type=int, default=1), + parser.add_argument('--steps', type=int, default=100000) + parser.add_argument('--multiagent', default=False, action='store_true'), + parser.add_argument('--lane', type=int, default=0) + parser.add_argument('--lights', default=False, action='store_true') + args = parser.parse_args() + return args + + +class CarlaEnv(object): + + def __init__(self, args): + self.render_display = False + self.record_display = False + self.record_vision = True + self.record_dir = None #'/nfs/kun1/users/aviralkumar/carla_data/' + self.vision_size = args.vision_size + self.vision_fov = args.vision_fov + self.changing_weather_speed = float(args.weather) + self.frame_skip = args.frame_skip + self.max_episode_steps = args.steps + self.multiagent = args.multiagent + self.start_lane = args.lane + self.follow_traffic_lights = args.lights + if self.record_display: + assert self.render_display + + self.actor_list = [] + + if self.render_display: + pygame.init() + self.render_display = pygame.display.set_mode((800, 600), pygame.HWSURFACE | pygame.DOUBLEBUF) + self.font = get_font() + self.clock = pygame.time.Clock() + + self.client = carla.Client('localhost', 2000) + self.client.set_timeout(2.0) + + self.world = self.client.get_world() + self.map = self.world.get_map() + + ## Define the route planner + self.route_planner_dao = GlobalRoutePlannerDAO(self.map, sampling_resolution=0.1) + self.route_planner = CustomGlobalRoutePlanner(self.route_planner_dao) + + # tests specific to map 4: + if self.start_lane and self.map.name != "Town04": + raise NotImplementedError + + # remove old vehicles and sensors (in case they survived) + self.world.tick() + actor_list = self.world.get_actors() + for vehicle in actor_list.filter("*vehicle*"): + print("Warning: removing old vehicle") + vehicle.destroy() + for sensor in actor_list.filter("*sensor*"): + print("Warning: removing old sensor") + sensor.destroy() + + self.vehicle = None + self.vehicles_list = [] # their ids + self.reset_vehicle() # creates self.vehicle + self.actor_list.append(self.vehicle) + + blueprint_library = self.world.get_blueprint_library() + + if self.render_display: + self.camera_display = self.world.spawn_actor( + blueprint_library.find('sensor.camera.rgb'), + carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), + attach_to=self.vehicle) + self.actor_list.append(self.camera_display) + + bp = blueprint_library.find('sensor.camera.rgb') + bp.set_attribute('image_size_x', str(self.vision_size)) + bp.set_attribute('image_size_y', str(self.vision_size)) + bp.set_attribute('fov', 
str(self.vision_fov)) + location = carla.Location(x=1.6, z=1.7) + self.camera_vision = self.world.spawn_actor(bp, carla.Transform(location, carla.Rotation(yaw=0.0)), attach_to=self.vehicle) + self.actor_list.append(self.camera_vision) + + if self.record_display or self.record_vision: + if self.record_dir is None: + self.record_dir = "carla-{}-{}x{}-fov{}".format( + self.map.name.lower(), self.vision_size, self.vision_size, self.vision_fov) + if self.frame_skip > 1: + self.record_dir += '-{}'.format(self.frame_skip) + if self.changing_weather_speed > 0.0: + self.record_dir += '-weather' + if self.multiagent: + self.record_dir += '-mutiagent' + if self.follow_traffic_lights: + self.record_dir += '-lights' + self.record_dir += '-{}k'.format(self.max_episode_steps // 1000) + + now = datetime.datetime.now() + self.record_dir += now.strftime("-%Y-%m-%d-%H-%M-%S") + if not os.path.exists(self.record_dir): + os.mkdir(self.record_dir) + + if self.render_display: + self.sync_mode = CarlaSyncMode(self.world, self.camera_display, self.camera_vision, fps=20) + else: + self.sync_mode = CarlaSyncMode(self.world, self.camera_vision, fps=20) + + # weather + self.weather = Weather(self.world, self.changing_weather_speed) + + # dummy variables, to match deep mind control's APIs + low = -1.0 + high = 1.0 + self.action_space = DotMap() + self.action_space.low.min = lambda: low + self.action_space.high.max = lambda: high + self.action_space.shape = [2] + self.observation_space = DotMap() + self.observation_space.shape = (3, self.vision_size, self.vision_size) + self.observation_space.dtype = np.dtype(np.uint8) + self.reward_range = None + self.metadata = None + self.action_space.sample = lambda: np.random.uniform(low=low, high=high, size=self.action_space.shape[0]).astype(np.float32) + + # roaming carla agent + self.agent = None + self.world.tick() + self.reset_init() # creates self.agent + + ## Initialize the route planner + self.route_planner.setup() + + ## Collision detection + self._proximity_threshold = 10.0 + self._traffic_light_threshold = 5.0 + self.actor_list = self.world.get_actors() + for idx in range(len(self.actor_list)): + print (idx, self.actor_list[idx]) + # import ipdb; ipdb.set_trace() + self.vehicle_list = self.actor_list.filter("*vehicle*") + self.lights_list = self.actor_list.filter("*traffic_light*") + self.object_list = self.actor_list.filter("*traffic.*") + + ## Initialize the route planner + self.route_planner.setup() + + ## The map is deterministic so for reward relabelling, we can + ## instantiate the environment object and then query the distance function + ## in the env, which directly uses this map_graph, and we need not save it. + self._map_graph = self.route_planner._graph + + ## This is a dummy for the target location, we can make this an input + ## to the env in RL code. 
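+ ## goal_reaching_reward() computes the route distance and the velocity component toward this location, so RL code can retarget the task by overwriting self.target_location after construction.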
+ self.target_location = carla.Location(x=-13.473097, y=134.311234, z=-0.010433) + + ## Now reset the env once + self.reset() + + + def reset_init(self): + self.reset_vehicle() + self.world.tick() + self.reset_other_vehicles() + self.world.tick() + self.agent = RoamingAgent(self.vehicle, follow_traffic_lights=self.follow_traffic_lights) + self.count = 0 + self.ts = int(time.time()) + + def reset(self): + # get obs: + obs, _, _, _ = self.step() + return obs + + def reset_vehicle(self): + + if self.map.name == "Town04": + start_lane = -1 + start_x = 5.0 + vehicle_init_transform = carla.Transform(carla.Location(x=start_x, y=0, z=0.1), carla.Rotation(yaw=-90)) + else: + init_transforms = self.world.get_map().get_spawn_points() + vehicle_init_transform = random.choice(init_transforms) + + # TODO(aviral): start lane not defined for town, also for the town, we may not want to have + # the lane following reward, so it should be okay. + + if self.vehicle is None: # then create the ego vehicle + blueprint_library = self.world.get_blueprint_library() + vehicle_blueprint = blueprint_library.find('vehicle.audi.a2') + self.vehicle = self.world.spawn_actor(vehicle_blueprint, vehicle_init_transform) + + self.vehicle.set_transform(vehicle_init_transform) + self.vehicle.set_velocity(carla.Vector3D()) + self.vehicle.set_angular_velocity(carla.Vector3D()) + + def reset_other_vehicles(self): + if not self.multiagent: + return + + # clear out old vehicles + self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) + self.world.tick() + self.vehicles_list = [] + + traffic_manager = self.client.get_trafficmanager() + traffic_manager.set_global_distance_to_leading_vehicle(2.0) + traffic_manager.set_synchronous_mode(True) + blueprints = self.world.get_blueprint_library().filter('vehicle.*') + blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4] + + num_vehicles = 20 + if self.map.name == "Town04": + road_id = 47 + road_length = 117. 
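+ # On Town04, spawn positions are sampled in OpenDRIVE coordinates (road 47, a random lane, random s along its 117 m length); other maps fall back to the map's predefined spawn points below.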
+ init_transforms = [] + for _ in range(num_vehicles): + lane_id = random.choice([-1, -2, -3, -4]) + vehicle_s = np.random.uniform(road_length) # length of road 47 + init_transforms.append(self.map.get_waypoint_xodr(road_id, lane_id, vehicle_s).transform) + else: + init_transforms = self.world.get_map().get_spawn_points() + init_transforms = np.random.choice(init_transforms, num_vehicles) + + # -------------- + # Spawn vehicles + # -------------- + batch = [] + for transform in init_transforms: + transform.location.z += 0.1 # otherwise can collide with the road it starts on + blueprint = random.choice(blueprints) + if blueprint.has_attribute('color'): + color = random.choice(blueprint.get_attribute('color').recommended_values) + blueprint.set_attribute('color', color) + if blueprint.has_attribute('driver_id'): + driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values) + blueprint.set_attribute('driver_id', driver_id) + blueprint.set_attribute('role_name', 'autopilot') + batch.append(carla.command.SpawnActor(blueprint, transform).then( + carla.command.SetAutopilot(carla.command.FutureActor, True))) + + for response in self.client.apply_batch_sync(batch, False): + self.vehicles_list.append(response.actor_id) + + for response in self.client.apply_batch_sync(batch): + if response.error: + pass + else: + self.vehicles_list.append(response.actor_id) + + traffic_manager.global_percentage_speed_difference(30.0) + + def compute_action(self): + return self.agent.run_step() + + def step(self, action=None, traffic_light_color=""): + rewards = [] + for _ in range(self.frame_skip): # default 1 + next_obs, reward, done, info = self._simulator_step(action, traffic_light_color) + rewards.append(reward) + if done: + break + return next_obs, np.mean(rewards), done, info + + def _is_vehicle_hazard(self, vehicle, vehicle_list): + """ + :param vehicle_list: list of potential obstacle to check + :return: a tuple given by (bool_flag, vehicle), where + - bool_flag is True if there is a vehicle ahead blocking us + and False otherwise + - vehicle is the blocker object itself + """ + + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for target_vehicle in vehicle_list: + # do not account for the ego vehicle + if target_vehicle.id == vehicle.id: + continue + + # if the object is not in our lane it's not an obstacle + target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) + if target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(target_vehicle.get_transform(), + vehicle.get_transform(), + self._proximity_threshold/10.0): + return (True, -1.0, target_vehicle) + + return (False, 0.0, None) + + def _is_object_hazard(self, vehicle, object_list): + """ + :param vehicle_list: list of potential obstacle to check + :return: a tuple given by (bool_flag, vehicle), where + - bool_flag is True if there is a vehicle ahead blocking us + and False otherwise + - vehicle is the blocker object itself + """ + + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for target_vehicle in object_list: + # do not account for the ego vehicle + if target_vehicle.id == vehicle.id: + continue + + # if the object is not in our lane it's not an obstacle + target_vehicle_waypoint = self.map.get_waypoint(target_vehicle.get_location()) + if 
target_vehicle_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + target_vehicle_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(target_vehicle.get_transform(), + vehicle.get_transform(), + self._proximity_threshold/40.0): + return (True, -1.0, target_vehicle) + + return (False, 0.0, None) + + def _is_light_red(self, vehicle): + """ + Method to check if there is a red light affecting us. This version of + the method is compatible with both European and US style traffic lights. + :param lights_list: list containing TrafficLight objects + :return: a tuple given by (bool_flag, traffic_light), where + - bool_flag is True if there is a traffic light in RED + affecting us and False otherwise + - traffic_light is the object itself or None if there is no + red traffic light affecting us + """ + ego_vehicle_location = vehicle.get_location() + ego_vehicle_waypoint = self.map.get_waypoint(ego_vehicle_location) + + for traffic_light in self.lights_list: + object_location = self._get_trafficlight_trigger_location(traffic_light) + object_waypoint = self.map.get_waypoint(object_location) + + if object_waypoint.road_id != ego_vehicle_waypoint.road_id: + continue + + ve_dir = ego_vehicle_waypoint.transform.get_forward_vector() + wp_dir = object_waypoint.transform.get_forward_vector() + dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z + + if dot_ve_wp < 0: + continue + + if is_within_distance_ahead(object_waypoint.transform, + vehicle.get_transform(), + self._traffic_light_threshold): + if traffic_light.state == carla.TrafficLightState.Red: + return (True, -0.1, traffic_light) + + return (False, 0.0, None) + + def _get_trafficlight_trigger_location(self, traffic_light): # pylint: disable=no-self-use + """ + Calculates the yaw of the waypoint that represents the trigger volume of the traffic light + """ + def rotate_point(point, radians): + """ + rotate a given point by a given angle + """ + rotated_x = math.cos(radians) * point.x - math.sin(radians) * point.y + rotated_y = math.sin(radians) * point.x - math.cos(radians) * point.y + + return carla.Vector3D(rotated_x, rotated_y, point.z) + + base_transform = traffic_light.get_transform() + base_rot = base_transform.rotation.yaw + area_loc = base_transform.transform(traffic_light.trigger_volume.location) + area_ext = traffic_light.trigger_volume.extent + + point = rotate_point(carla.Vector3D(0, 0, area_ext.z), math.radians(base_rot)) + point_location = area_loc + carla.Location(x=point.x, y=point.y) + + return carla.Location(point_location.x, point_location.y, point_location.z) + + def _get_collision_reward(self, vehicle): + vehicle_hazard, reward, vehicle_id = self._is_vehicle_hazard(vehicle, self.vehicle_list) + return vehicle_hazard, reward + + def _get_traffic_light_reward(self, vehicle): + traffic_light_hazard, reward, traffic_light_id = self._is_light_red(vehicle) + return traffic_light_hazard, 0.0 + + def _get_object_collided_reward(self, vehicle): + object_hazard, reward, object_id = self._is_object_hazard(vehicle, self.object_list) + return object_hazard, reward + + def goal_reaching_reward(self, vehicle): + # Now we will write goal_reaching_rewards + vehicle_location = vehicle.get_location() + target_location = self.target_location + + # This is the distance computation + """ + dist = self.route_planner.compute_distance(vehicle_location, target_location) + + base_reward = -1.0 * dist + collided_done, collision_reward = self._get_collision_reward(vehicle) + 
traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) + object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) + total_reward = base_reward + 100 * collision_reward + 100 * traffic_light_reward + 100.0 * object_collided_reward + """ + + vehicle_velocity = vehicle.get_velocity() + dist = self.route_planner.compute_distance(vehicle_location, target_location) + vel_forward, vel_perp = self.route_planner.compute_direction_velocities(vehicle_location, vehicle_velocity, target_location) + #print('[GoalReachReward] VehLoc: %s Target: %s Dist: %s VelF:%s' % (str(vehicle_location), str(target_location), str(dist), str(vel_forward))) + #base_reward = -1.0 * (dist / 100.0) + 5.0 + base_reward = vel_forward + collided_done, collision_reward = self._get_collision_reward(vehicle) + traffic_light_done, traffic_light_reward = self._get_traffic_light_reward(vehicle) + object_collided_done, object_collided_reward = self._get_object_collided_reward(vehicle) + total_reward = base_reward + 100 * collision_reward # + 100 * traffic_light_reward + 100.0 * object_collided_reward + + reward_dict = dict() + reward_dict['collision'] = collision_reward + reward_dict['traffic_light'] = traffic_light_reward + reward_dict['object_collision'] = object_collided_reward + reward_dict['base_reward'] = base_reward + reward_dict['vel_forward'] = vel_forward + reward_dict['vel_perp'] = vel_perp + done_dict = dict() + done_dict['collided_done'] = collided_done + done_dict['traffic_light_done'] = traffic_light_done + done_dict['object_collided_done'] = object_collided_done + return total_reward, reward_dict, done_dict + + def _simulator_step(self, action, traffic_light_color): + + if self.render_display: + if should_quit(): + return + self.clock.tick() + + if action is None: + throttle, steer, brake = 0., 0., 0. + else: + throttle, steer, brake = action.throttle, action.steer, action.brake + # throttle = clamp(throttle, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) + # steer = clamp(steer, minimum=-0.995, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) + # brake = clamp(brake, minimum=0.005, maximum=0.995) + np.random.uniform(low=-0.003, high=0.003) + + vehicle_control = carla.VehicleControl( + throttle=throttle, # [0,1] + steer=steer, # [-1,1] + brake=brake, # [0,1] + hand_brake=False, + reverse=False, + manual_gear_shift=False + ) + self.vehicle.apply_control(vehicle_control) + + # Advance the simulation and wait for the data. + if self.render_display: + snapshot, display_image, vision_image = self.sync_mode.tick(timeout=2.0) + else: + snapshot, vision_image = self.sync_mode.tick(timeout=2.0) + + # Weather evolves + self.weather.tick() + + # Draw the display. 
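+ # (The pygame HUD overlays the frame count, current throttle/steer/brake, the reported traffic-light color, and the weather string.)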
+ if self.render_display: + draw_image(self.render_display, display_image) + self.render_display.blit(self.font.render('Frame %d' % self.count, True, (255, 255, 255)), (8, 10)) + self.render_display.blit(self.font.render('Control: %5.2f thottle, %5.2f steer, %5.2f brake' % (throttle, steer, brake), True, (255, 255, 255)), (8, 28)) + self.render_display.blit(self.font.render('Traffic light: ' + traffic_light_color, True, (255, 255, 255)), (8, 46)) + self.render_display.blit(self.font.render(str(self.weather), True, (255, 255, 255)), (8, 64)) + pygame.display.flip() + + # Format rl image + bgra = np.array(vision_image.raw_data).reshape(self.vision_size, self.vision_size, 4) # BGRA format + bgr = bgra[:, :, :3] # BGR format (84 x 84 x 3) + rgb = np.flip(bgr, axis=2) # RGB format (84 x 84 x 3) + + reward, reward_dict, done_dict = self.goal_reaching_reward(self.vehicle) + + if self.render_display and self.record_display: + image_name = os.path.join(self.record_dir, "display%08d.jpg" % self.count) + pygame.image.save(self.render_display, image_name) + # # Can animate with: + # ffmpeg -r 20 -pattern_type glob -i 'display*.jpg' carla.mp4 + if self.record_vision: + image_name = os.path.join(self.record_dir, "vision_%d_%08d.png" % (self.ts, self.count)) + im = Image.fromarray(rgb) + # add any eta data you like into the image before we save it: + metadata = PngInfo() + # control + metadata.add_text("control_throttle", str(throttle)) + metadata.add_text("control_steer", str(steer)) + metadata.add_text("control_brake", str(brake)) + metadata.add_text("control_repeat", str(self.frame_skip)) + # acceleration + acceleration = self.vehicle.get_acceleration() + metadata.add_text("acceleration_x", str(acceleration.x)) + metadata.add_text("acceleration_y", str(acceleration.y)) + metadata.add_text("acceleration_z", str(acceleration.z)) + # angular velocity + angular_velocity = self.vehicle.get_angular_velocity() + metadata.add_text("angular_velocity_x", str(angular_velocity.x)) + metadata.add_text("angular_velocity_y", str(angular_velocity.y)) + metadata.add_text("angular_velocity_z", str(angular_velocity.z)) + # location + location = self.vehicle.get_location() + print('Location:', location) + metadata.add_text("location_x", str(location.x)) + metadata.add_text("location_y", str(location.y)) + metadata.add_text("location_z", str(location.z)) + # rotation + rotation = self.vehicle.get_transform().rotation + metadata.add_text("rotation_pitch", str(rotation.pitch)) + metadata.add_text("rotation_yaw", str(rotation.yaw)) + metadata.add_text("rotation_roll", str(rotation.roll)) + forward_vector = rotation.get_forward_vector() + metadata.add_text("forward_vector_x", str(forward_vector.x)) + metadata.add_text("forward_vector_y", str(forward_vector.y)) + metadata.add_text("forward_vector_z", str(forward_vector.z)) + # velocity + velocity = self.vehicle.get_velocity() + metadata.add_text("velocity_x", str(velocity.x)) + metadata.add_text("velocity_y", str(velocity.y)) + metadata.add_text("velocity_z", str(velocity.z)) + # weather + metadata.add_text("weather_cloudiness ", str(self.weather.weather.cloudiness)) + metadata.add_text("weather_precipitation", str(self.weather.weather.precipitation)) + metadata.add_text("weather_precipitation_deposits", str(self.weather.weather.precipitation_deposits)) + metadata.add_text("weather_wind_intensity", str(self.weather.weather.wind_intensity)) + metadata.add_text("weather_fog_density", str(self.weather.weather.fog_density)) + metadata.add_text("weather_wetness", 
str(self.weather.weather.wetness)) + metadata.add_text("weather_sun_azimuth_angle", str(self.weather.weather.sun_azimuth_angle)) + # settings + metadata.add_text("settings_map", self.map.name) + metadata.add_text("settings_vision_size", str(self.vision_size)) + metadata.add_text("settings_vision_fov", str(self.vision_fov)) + metadata.add_text("settings_changing_weather_speed", str(self.changing_weather_speed)) + metadata.add_text("settings_multiagent", str(self.multiagent)) + # traffic lights + metadata.add_text("traffic_lights_color", "UNLABELED") + metadata.add_text("reward", str(reward)) + + ## Add in reward dict + for key in reward_dict: + metadata.add_text("reward_" + str(key), str(reward_dict[key])) + + for key in done_dict: + metadata.add_text("done_" + str(key), str(done_dict[key])) + + ## Save the target location as well + metadata.add_text('target_location_x', str(self.target_location.x)) + metadata.add_text('target_location_y', str(self.target_location.y)) + metadata.add_text('target_location_z', str(self.target_location.z)) + + im.save(image_name, "PNG", pnginfo=metadata) + + # # To read these images later, you can run something like this: + # from PIL.PngImagePlugin import PngImageFile + # im = PngImageFile("vision00001234.png") + # throttle = float(im.text['throttle']) # range [0, 1] + # steer = float(im.text['steer']) # range [-1, 1] + # brake = float(im.text['brake']) # range [0, 1] + # lights = im.text['lights'] # traffic lights color, [NONE, JUNCTION, RED, YELLOW, GREEN] + self.count += 1 + + next_obs = rgb # 84 x 84 x 3 + # # To inspect images, run: + # import pdb; pdb.set_trace() + # import matplotlib.pyplot as plt + # plt.imshow(next_obs) + # plt.show() + + done = False #self.count >= self.max_episode_steps + if done: + print("Episode success: I've reached the episode horizon ({}).".format(self.max_episode_steps)) + # print ('reward: ', reward) + info = reward_dict + info.update(done_dict) + done = False + for key in done_dict: + done = (done or done_dict[key]) + return next_obs, reward, done, info + + def finish(self): + print('destroying actors.') + for actor in self.actor_list: + actor.destroy() + print('\ndestroying %d vehicles' % len(self.vehicles_list)) + self.client.apply_batch([carla.command.DestroyActor(x) for x in self.vehicles_list]) + time.sleep(0.5) + pygame.quit() + print('done.') + + +class LocalPlannerModified(LocalPlanner): + + def __del__(self): + pass # otherwise it deletes our vehicle object + + def run_step(self): + return super().run_step(debug=False) # otherwise by default shows waypoints, that interfere with our camera + + +class RoamingAgent(Agent): + """ + RoamingAgent implements a basic agent that navigates scenes making random + choices when facing an intersection. + + This agent respects traffic lights and other vehicles. + """ + + def __init__(self, vehicle, follow_traffic_lights=True): + """ + + :param vehicle: actor to apply to local planner logic onto + """ + super(RoamingAgent, self).__init__(vehicle) + self._proximity_threshold = 10.0 # meters + self._state = AgentState.NAVIGATING + self._local_planner = LocalPlannerModified(self._vehicle) + self._follow_traffic_lights = follow_traffic_lights + + def run_step(self): + """ + Execute one step of navigation. + :return: carla.VehicleControl + """ + + # is there an obstacle in front of us? 
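+ # get_actors() is re-queried on every step, so traffic spawned after a reset is included in the hazard checks below.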
+ hazard_detected = False + + # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles + actor_list = self._world.get_actors() + vehicle_list = actor_list.filter("*vehicle*") + lights_list = actor_list.filter("*traffic_light*") + + # check possible obstacles + vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) + if vehicle_state: + + self._state = AgentState.BLOCKED_BY_VEHICLE + hazard_detected = True + + # check for the state of the traffic lights + traffic_light_color = self._is_light_red(lights_list) + if traffic_light_color == 'RED' and self._follow_traffic_lights: + self._state = AgentState.BLOCKED_RED_LIGHT + hazard_detected = True + + if hazard_detected: + control = self.emergency_stop() + else: + self._state = AgentState.NAVIGATING + # standard local planner behavior + control = self._local_planner.run_step() + + return control, traffic_light_color + + # override case class + def _is_light_red_europe_style(self, lights_list): + """ + This method is specialized to check European style traffic lights. + Only suitable for Towns 03 -- 07. + """ + ego_vehicle_location = self._vehicle.get_location() + ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) + + traffic_light_color = "NONE" # default, if no traffic lights are seen + + for traffic_light in lights_list: + object_waypoint = self._map.get_waypoint(traffic_light.get_location()) + if object_waypoint.road_id != ego_vehicle_waypoint.road_id or \ + object_waypoint.lane_id != ego_vehicle_waypoint.lane_id: + continue + + if is_within_distance_ahead(traffic_light.get_transform(), + self._vehicle.get_transform(), + self._proximity_threshold): + if traffic_light.state == carla.TrafficLightState.Red: + return "RED" + elif traffic_light.state == carla.TrafficLightState.Yellow: + traffic_light_color = "YELLOW" + elif traffic_light.state == carla.TrafficLightState.Green: + if traffic_light_color is not "YELLOW": # (more severe) + traffic_light_color = "GREEN" + else: + import pdb; pdb.set_trace() + # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate + + return traffic_light_color + + # override case class + def _is_light_red_us_style(self, lights_list, debug=False): + ego_vehicle_location = self._vehicle.get_location() + ego_vehicle_waypoint = self._map.get_waypoint(ego_vehicle_location) + + traffic_light_color = "NONE" # default, if no traffic lights are seen + + if ego_vehicle_waypoint.is_junction: + # It is too late. Do not block the intersection! Keep going! 
+ return "JUNCTION" + + if self._local_planner.target_waypoint is not None: + if self._local_planner.target_waypoint.is_junction: + min_angle = 180.0 + sel_magnitude = 0.0 + sel_traffic_light = None + for traffic_light in lights_list: + loc = traffic_light.get_location() + magnitude, angle = compute_magnitude_angle(loc, + ego_vehicle_location, + self._vehicle.get_transform().rotation.yaw) + if magnitude < 60.0 and angle < min(25.0, min_angle): + sel_magnitude = magnitude + sel_traffic_light = traffic_light + min_angle = angle + + if sel_traffic_light is not None: + if debug: + print('=== Magnitude = {} | Angle = {} | ID = {}'.format( + sel_magnitude, min_angle, sel_traffic_light.id)) + + if self._last_traffic_light is None: + self._last_traffic_light = sel_traffic_light + + if self._last_traffic_light.state == carla.TrafficLightState.Red: + return "RED" + elif self._last_traffic_light.state == carla.TrafficLightState.Yellow: + traffic_light_color = "YELLOW" + elif self._last_traffic_light.state == carla.TrafficLightState.Green: + if traffic_light_color is not "YELLOW": # (more severe) + traffic_light_color = "GREEN" + else: + import pdb; pdb.set_trace() + # investigate https://carla.readthedocs.io/en/latest/python_api/#carlatrafficlightstate + else: + self._last_traffic_light = None + + return traffic_light_color + + +if __name__ == '__main__': + + # example call: + # ./PythonAPI/util/config.py --map Town01 --delta-seconds 0.05 + # python PythonAPI/carla/agents/navigation/data_collection_agent.py --vision_size 256 --vision_fov 90 --steps 10000 --weather --lights + + args = parse_args() + env = CarlaEnv(args) + + curr_steps = 0 + try: + done = False + while not done: + curr_steps += 1 + action, traffic_light_color = env.compute_action() + next_obs, reward, done, info = env.step(action, traffic_light_color) + print ('Reward: ', reward, 'Done: ', done, 'Location: ', env.vehicle.get_location()) + if done: + # env.reset_init() + # env.reset() + done = False + + if curr_steps % 5000 == 4999: + env.reset_init() + env.reset() + finally: + env.finish() diff --git a/d4rl/d4rl/carla/town_agent.py b/d4rl/d4rl/carla/town_agent.py new file mode 100644 index 0000000..f992d8b --- /dev/null +++ b/d4rl/d4rl/carla/town_agent.py @@ -0,0 +1,150 @@ +# A baseline town agent. +from agents.navigation.agent import Agent, AgentState +import numpy as np +from agents.navigation.local_planner import LocalPlanner + +class RoamingAgent(Agent): + """ + RoamingAgent implements a basic agent that navigates scenes making random + choices when facing an intersection. + + This agent respects traffic lights and other vehicles. + + NOTE: need to re-create after each env reset + """ + + def __init__(self, env): + """ + + :param vehicle: actor to apply to local planner logic onto + """ + vehicle = env.vehicle + follow_traffic_lights = env.follow_traffic_lights + super(RoamingAgent, self).__init__(vehicle) + self._proximity_threshold = 10.0 # meters + self._state = AgentState.NAVIGATING + self._local_planner = LocalPlannerModified(self._vehicle) + self._follow_traffic_lights = follow_traffic_lights + + def compute_action(self): + action, traffic_light = self.run_step() + throttle = action.throttle + brake = action.brake + steer = action.steer + #print('tbsl:', throttle, brake, steer, traffic_light) + if brake == 0.0: + return np.array([throttle, steer]) + else: + return np.array([-brake, steer]) + + def run_step(self): + """ + Execute one step of navigation. 
+    def __init__(self, env):
+        """
+        :param env: environment providing the vehicle to apply the local planner logic onto
+        """
+        vehicle = env.vehicle
+        follow_traffic_lights = env.follow_traffic_lights
+        super(RoamingAgent, self).__init__(vehicle)
+        self._proximity_threshold = 10.0  # meters
+        self._state = AgentState.NAVIGATING
+        self._local_planner = LocalPlannerModified(self._vehicle)
+        self._follow_traffic_lights = follow_traffic_lights
+
+    def compute_action(self):
+        action, traffic_light = self.run_step()
+        throttle = action.throttle
+        brake = action.brake
+        steer = action.steer
+        #print('tbsl:', throttle, brake, steer, traffic_light)
+        if brake == 0.0:
+            return np.array([throttle, steer])
+        else:
+            return np.array([-brake, steer])
+
+    def run_step(self):
+        """
+        Execute one step of navigation.
+        :return: np.ndarray action of the form [throttle, steer], or [-brake, steer] when braking
+        """
+
+        # is there an obstacle in front of us?
+        hazard_detected = False
+
+        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles
+        actor_list = self._world.get_actors()
+        vehicle_list = actor_list.filter("*vehicle*")
+        lights_list = actor_list.filter("*traffic_light*")
+
+        # check possible obstacles
+        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
+        if vehicle_state:
+
+            self._state = AgentState.BLOCKED_BY_VEHICLE
+            hazard_detected = True
+
+        # note: traffic lights are not checked in this run_step variant
+        if hazard_detected:
+            control = self.emergency_stop()
+        else:
+            self._state = AgentState.NAVIGATING
+            # standard local planner behavior
+            control = self._local_planner.run_step()
+
+        throttle = control.throttle
+        brake = control.brake
+        steer = control.steer
+        #print('tbsl:', throttle, brake, steer, traffic_light)
+        if brake == 0.0:
+            return np.array([throttle, steer])
+        else:
+            return np.array([-brake, steer])
+
+
+class LocalPlannerModified(LocalPlanner):
+
+    def __del__(self):
+        pass  # otherwise it deletes our vehicle object
+
+    def run_step(self):
+        return super().run_step(debug=False)  # otherwise it shows waypoints by default, which interfere with our camera
+
+
+class DummyTownAgent(Agent):
+    """
+    A simple agent for the town driving task.
+
+    If the car is currently facing along a path towards the goal, drive forward.
+    If the car would start driving away, apply maximum brakes.
+    """
+
+    def __init__(self, env):
+        """
+        :param env: environment providing the vehicle to apply the local planner logic onto
+        """
+        self.env = env
+        super(DummyTownAgent, self).__init__(self.env.vehicle)
+        self._proximity_threshold = 10.0  # meters
+        self._state = AgentState.NAVIGATING
+        self._local_planner = LocalPlannerModified(self._vehicle)
+
+    def compute_action(self):
+
+        hazard_detected = False
+        # retrieve relevant elements for safe navigation, i.e.: traffic lights and other vehicles
+        actor_list = self._world.get_actors()
+        vehicle_list = actor_list.filter("*vehicle*")
+        lights_list = actor_list.filter("*traffic_light*")
+        # check possible obstacles
+        vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list)
+        if vehicle_state:
+            self._state = AgentState.BLOCKED_BY_VEHICLE
+            hazard_detected = True
+
+
+        rotation = self.env.vehicle.get_transform().rotation
+        forward_vector = rotation.get_forward_vector()
+        origin = self.env.vehicle.get_location()
+        destination = self.env.target_location
+        node_list = self.env.route_planner._path_search(origin=origin, destination=destination)
+        origin_xy = np.array([origin.x, origin.y])
+        forward_xy = np.array([forward_vector.x, forward_vector.y])
+        first_node_xy = self.env.route_planner._graph.nodes[node_list[0]]['vertex']
+        first_node_xy = np.array([first_node_xy[0], first_node_xy[1]])
+        target_direction_vector = first_node_xy - origin_xy
+        target_unit_vector = np.array(target_direction_vector) / np.linalg.norm(target_direction_vector)
+        vel_s = np.dot(forward_xy, target_unit_vector)
+        if vel_s < 0:
+            hazard_detected = True
+
+
+        if hazard_detected:
+            control = self.emergency_stop()
+        else:
+            self._state = AgentState.NAVIGATING
+            # standard local planner behavior
+            control = self._local_planner.run_step()
+        throttle = control.throttle
+        brake = control.brake
+        steer = control.steer
+        #print('tbsl:', throttle, brake, steer, traffic_light)
+        if brake == 0.0:
+            return np.array([throttle, steer])
+        else:
+            return np.array([-brake, steer])
diff --git a/d4rl/d4rl/flow/__init__.py b/d4rl/d4rl/flow/__init__.py new file mode
100644 index 0000000..ec96e9b --- /dev/null +++ b/d4rl/d4rl/flow/__init__.py @@ -0,0 +1,225 @@ +import gym +import os +from d4rl import offline_env +from gym.envs.registration import register + +from copy import deepcopy + +import flow +import flow.envs +from flow.networks.ring import RingNetwork +from flow.core.params import NetParams, VehicleParams, EnvParams, InFlows +from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams +from flow.networks.ring import ADDITIONAL_NET_PARAMS +from flow.controllers.car_following_models import IDMController +from flow.controllers.routing_controllers import ContinuousRouter +from flow.controllers import SimCarFollowingController, SimLaneChangeController +from flow.controllers import RLController +from flow.core.params import InitialConfig +from flow.core.params import TrafficLightParams +from flow.envs.ring.accel import AccelEnv +from flow.core.params import SumoParams +from flow.utils.registry import make_create_env +from flow.envs import WaveAttenuationPOEnv +from flow.envs import BayBridgeEnv, TrafficLightGridPOEnv + +from d4rl.flow import traffic_light_grid +from d4rl.flow import merge +from d4rl.flow import bottleneck + +def flow_register(flow_params, render=None, **kwargs): + exp_tag = flow_params["exp_tag"] + env_params = flow_params['env'] + net_params = flow_params['net'] + env_class = flow_params['env_name'] + initial_config = flow_params.get('initial', InitialConfig()) + traffic_lights = flow_params.get("tls", TrafficLightParams()) + sim_params = deepcopy(flow_params['sim']) + vehicles = deepcopy(flow_params['veh']) + + sim_params.render = render or sim_params.render + + if isinstance(flow_params["network"], str): + print("""Passing of strings for network will be deprecated. + Please pass the Network instance instead.""") + module = __import__("flow.networks", fromlist=[flow_params["network"]]) + network_class = getattr(module, flow_params["network"]) + else: + network_class = flow_params["network"] + + network = network_class( + name=exp_tag, + vehicles=vehicles, + net_params=net_params, + initial_config=initial_config, + traffic_lights=traffic_lights, + ) + + flow_env = env_class( + env_params= env_params, + sim_params= sim_params, + network= network, + simulator= flow_params['simulator'] + ) + + env = offline_env.OfflineEnvWrapper(flow_env, + **kwargs + ) + return env + + +def ring_env(render='drgb'): + name = "ring" + network_name = RingNetwork + env_name = WaveAttenuationPOEnv + + net_params = NetParams(additional_params=ADDITIONAL_NET_PARAMS) + initial_config = InitialConfig(spacing="uniform", shuffle=False) + + vehicles = VehicleParams() + vehicles.add("human", + acceleration_controller=(IDMController, {}), + routing_controller=(ContinuousRouter, {}), + num_vehicles=21) + vehicles.add(veh_id="rl", + acceleration_controller=(RLController, {}), + routing_controller=(ContinuousRouter, {}), + num_vehicles=1) + + sim_params = SumoParams(sim_step=0.5, render=render, save_render=True) + HORIZON=100 + env_params = EnvParams( + # length of one rollout + horizon=HORIZON, + additional_params={ + # maximum acceleration of autonomous vehicles + "max_accel": 1, + # maximum deceleration of autonomous vehicles + "max_decel": 1, + # bounds on the ranges of ring road lengths the autonomous vehicle + # is trained on + "ring_length": [220, 270], + }, + ) + + + flow_params = dict( + exp_tag=name, + env_name=env_name, + network=network_name, + simulator='traci', + sim=sim_params, + env=env_params, + net=net_params, + veh=vehicles, + 
initial=initial_config + ) + return flow_params + + +RING_RANDOM_SCORE = -165.22 +RING_EXPERT_SCORE = 24.42 + +register( + id='flow-ring-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=500, + kwargs={ + 'flow_params': ring_env(render=False), + 'dataset_url': None, + 'ref_min_score': RING_RANDOM_SCORE, + 'ref_max_score': RING_EXPERT_SCORE + } +) + + +register( + id='flow-ring-render-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=500, + kwargs={ + 'flow_params': ring_env(render='drgb'), + 'dataset_url': None, + 'ref_min_score': RING_RANDOM_SCORE, + 'ref_max_score': RING_EXPERT_SCORE + } +) + +register( + id='flow-ring-random-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=500, + kwargs={ + 'flow_params': ring_env(render=False), + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5', + 'ref_min_score': RING_RANDOM_SCORE, + 'ref_max_score': RING_EXPERT_SCORE + } +) + + +register( + id='flow-ring-controller-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=500, + kwargs={ + 'flow_params': ring_env(render=False), + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5', + 'ref_min_score': RING_RANDOM_SCORE, + 'ref_max_score': RING_EXPERT_SCORE + } +) + + +MERGE_RANDOM_SCORE = 118.67993 +MERGE_EXPERT_SCORE = 330.03179 + +register( + id='flow-merge-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=750, + kwargs={ + 'flow_params': merge.gen_env(render=False), + 'dataset_url': None, + 'ref_min_score': MERGE_RANDOM_SCORE, + 'ref_max_score': MERGE_EXPERT_SCORE + } +) + + +register( + id='flow-merge-render-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=750, + kwargs={ + 'flow_params': merge.gen_env(render='drgb'), + 'dataset_url': None, + 'ref_min_score': MERGE_RANDOM_SCORE, + 'ref_max_score': MERGE_EXPERT_SCORE + } +) + +register( + id='flow-merge-random-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=750, + kwargs={ + 'flow_params': merge.gen_env(render=False), + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5', + 'ref_min_score': MERGE_RANDOM_SCORE, + 'ref_max_score': MERGE_EXPERT_SCORE + } +) + +register( + id='flow-merge-controller-v0', + entry_point='d4rl.flow:flow_register', + max_episode_steps=750, + kwargs={ + 'flow_params': merge.gen_env(render=False), + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5', + 'ref_min_score': MERGE_RANDOM_SCORE, + 'ref_max_score': MERGE_EXPERT_SCORE + } +) + diff --git a/d4rl/d4rl/flow/bottleneck.py b/d4rl/d4rl/flow/bottleneck.py new file mode 100644 index 0000000..4d22111 --- /dev/null +++ b/d4rl/d4rl/flow/bottleneck.py @@ -0,0 +1,149 @@ +import flow +import flow.envs +from flow.core.params import NetParams, VehicleParams, EnvParams, InFlows +from flow.core.params import SumoLaneChangeParams, SumoCarFollowingParams +from flow.networks.ring import ADDITIONAL_NET_PARAMS +from flow.controllers.routing_controllers import ContinuousRouter +from flow.controllers import SimCarFollowingController, SimLaneChangeController +from flow.controllers import RLController +from flow.core.params import InitialConfig +from flow.core.params import TrafficLightParams +from flow.core.params import SumoParams +from flow.envs import BottleneckDesiredVelocityEnv +from flow.networks import BottleneckNetwork + +def bottleneck(render='drgb'): + # time horizon of a single rollout + HORIZON = 1500 + 
+ SCALING = 1 + NUM_LANES = 4 * SCALING # number of lanes in the widest highway + DISABLE_TB = True + DISABLE_RAMP_METER = True + AV_FRAC = 0.10 + + vehicles = VehicleParams() + vehicles.add( + veh_id="human", + routing_controller=(ContinuousRouter, {}), + car_following_params=SumoCarFollowingParams( + speed_mode=9, + ), + lane_change_params=SumoLaneChangeParams( + lane_change_mode=0, + ), + num_vehicles=1 * SCALING) + vehicles.add( + veh_id="rl", + acceleration_controller=(RLController, {}), + routing_controller=(ContinuousRouter, {}), + car_following_params=SumoCarFollowingParams( + speed_mode=9, + ), + lane_change_params=SumoLaneChangeParams( + lane_change_mode=0, + ), + num_vehicles=1 * SCALING) + + controlled_segments = [("1", 1, False), ("2", 2, True), ("3", 2, True), + ("4", 2, True), ("5", 1, False)] + num_observed_segments = [("1", 1), ("2", 3), ("3", 3), ("4", 3), ("5", 1)] + + additional_env_params = { + "target_velocity": 40, + "disable_tb": True, + "disable_ramp_metering": True, + "controlled_segments": controlled_segments, + "symmetric": False, + "observed_segments": num_observed_segments, + "reset_inflow": False, + "lane_change_duration": 5, + "max_accel": 3, + "max_decel": 3, + "inflow_range": [1200, 2500] + } + + # flow rate + flow_rate = 2500 * SCALING + + # percentage of flow coming out of each lane + inflow = InFlows() + inflow.add( + veh_type="human", + edge="1", + vehs_per_hour=flow_rate * (1 - AV_FRAC), + depart_lane="random", + depart_speed=10) + inflow.add( + veh_type="rl", + edge="1", + vehs_per_hour=flow_rate * AV_FRAC, + depart_lane="random", + depart_speed=10) + + traffic_lights = TrafficLightParams() + if not DISABLE_TB: + traffic_lights.add(node_id="2") + if not DISABLE_RAMP_METER: + traffic_lights.add(node_id="3") + + additional_net_params = {"scaling": SCALING, "speed_limit": 23} + net_params = NetParams( + inflows=inflow, + additional_params=additional_net_params) + + flow_params = dict( + # name of the experiment + exp_tag="bottleneck_0", + + # name of the flow environment the experiment is running on + env_name=BottleneckDesiredVelocityEnv, + + # name of the network class the experiment is running on + network=BottleneckNetwork, + + # simulator that is used by the experiment + simulator='traci', + + # sumo-related parameters (see flow.core.params.SumoParams) + sim=SumoParams( + sim_step=0.5, + render=render, + save_render=True, + print_warnings=False, + restart_instance=True, + ), + + # environment related parameters (see flow.core.params.EnvParams) + env=EnvParams( + warmup_steps=40, + sims_per_step=1, + horizon=HORIZON, + additional_params=additional_env_params, + ), + + # network-related parameters (see flow.core.params.NetParams and the + # network's documentation or ADDITIONAL_NET_PARAMS component) + net=NetParams( + inflows=inflow, + additional_params=additional_net_params, + ), + + # vehicles to be placed in the network at the start of a rollout (see + # flow.core.params.VehicleParams) + veh=vehicles, + + # parameters specifying the positioning of vehicles upon initialization/ + # reset (see flow.core.params.InitialConfig) + initial=InitialConfig( + spacing="uniform", + min_gap=5, + lanes_distribution=float("inf"), + edges_distribution=["2", "3", "4", "5"], + ), + + # traffic lights to be introduced to specific nodes (see + # flow.core.params.TrafficLightParams) + tls=traffic_lights, + ) + return flow_params diff --git a/d4rl/d4rl/flow/merge.py b/d4rl/d4rl/flow/merge.py new file mode 100644 index 0000000..6105822 --- /dev/null +++ 
b/d4rl/d4rl/flow/merge.py @@ -0,0 +1,119 @@ +"""Open merge example. +Trains a a small percentage of rl vehicles to dissipate shockwaves caused by +on-ramp merge to a single lane open highway network. +""" +from flow.envs import MergePOEnv +from flow.networks import MergeNetwork +from copy import deepcopy +from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \ + InFlows, SumoCarFollowingParams +from flow.networks.merge import ADDITIONAL_NET_PARAMS +from flow.core.params import VehicleParams +from flow.controllers import SimCarFollowingController, RLController + +def gen_env(render='drgb'): + # time horizon of a single rollout + HORIZON = 750 + # inflow rate at the highway + FLOW_RATE = 2000 + # percent of autonomous vehicles + RL_PENETRATION = 0.1 + # num_rl term (see ADDITIONAL_ENV_PARAMs) + NUM_RL = 5 + + # We consider a highway network with an upstream merging lane producing + # shockwaves + additional_net_params = deepcopy(ADDITIONAL_NET_PARAMS) + additional_net_params["merge_lanes"] = 1 + additional_net_params["highway_lanes"] = 1 + additional_net_params["pre_merge_length"] = 500 + + # RL vehicles constitute 5% of the total number of vehicles + vehicles = VehicleParams() + vehicles.add( + veh_id="human", + acceleration_controller=(SimCarFollowingController, {}), + car_following_params=SumoCarFollowingParams( + speed_mode=9, + ), + num_vehicles=5) + vehicles.add( + veh_id="rl", + acceleration_controller=(RLController, {}), + car_following_params=SumoCarFollowingParams( + speed_mode=9, + ), + num_vehicles=0) + + # Vehicles are introduced from both sides of merge, with RL vehicles entering + # from the highway portion as well + inflow = InFlows() + inflow.add( + veh_type="human", + edge="inflow_highway", + vehs_per_hour=(1 - RL_PENETRATION) * FLOW_RATE, + depart_lane="free", + depart_speed=10) + inflow.add( + veh_type="rl", + edge="inflow_highway", + vehs_per_hour=RL_PENETRATION * FLOW_RATE, + depart_lane="free", + depart_speed=10) + inflow.add( + veh_type="human", + edge="inflow_merge", + vehs_per_hour=100, + depart_lane="free", + depart_speed=7.5) + + flow_params = dict( + # name of the experiment + exp_tag="merge_0", + + # name of the flow environment the experiment is running on + env_name=MergePOEnv, + + # name of the network class the experiment is running on + network=MergeNetwork, + + # simulator that is used by the experiment + simulator='traci', + + # sumo-related parameters (see flow.core.params.SumoParams) + sim=SumoParams( + restart_instance=True, + sim_step=0.5, + render=render, + save_render=True + ), + + # environment related parameters (see flow.core.params.EnvParams) + env=EnvParams( + horizon=HORIZON, + sims_per_step=2, + warmup_steps=0, + additional_params={ + "max_accel": 1.5, + "max_decel": 1.5, + "target_velocity": 20, + "num_rl": NUM_RL, + }, + ), + + # network-related parameters (see flow.core.params.NetParams and the + # network's documentation or ADDITIONAL_NET_PARAMS component) + net=NetParams( + inflows=inflow, + additional_params=additional_net_params, + ), + + # vehicles to be placed in the network at the start of a rollout (see + # flow.core.params.VehicleParams) + veh=vehicles, + + # parameters specifying the positioning of vehicles upon initialization/ + # reset (see flow.core.params.InitialConfig) + initial=InitialConfig(), + ) + return flow_params diff --git a/d4rl/d4rl/flow/traffic_light_grid.py b/d4rl/d4rl/flow/traffic_light_grid.py new file mode 100644 index 0000000..dd5bee5 --- /dev/null +++ 
b/d4rl/d4rl/flow/traffic_light_grid.py @@ -0,0 +1,128 @@ +"""Traffic Light Grid example.""" +from flow.envs import TrafficLightGridBenchmarkEnv +from flow.networks import TrafficLightGridNetwork +from flow.core.params import SumoParams, EnvParams, InitialConfig, NetParams, \ + InFlows, SumoCarFollowingParams +from flow.core.params import VehicleParams +from flow.controllers import SimCarFollowingController, GridRouter + +def gen_env(render='drgb'): + # time horizon of a single rollout + HORIZON = 400 + # inflow rate of vehicles at every edge + EDGE_INFLOW = 300 + # enter speed for departing vehicles + V_ENTER = 30 + # number of row of bidirectional lanes + N_ROWS = 3 + # number of columns of bidirectional lanes + N_COLUMNS = 3 + # length of inner edges in the grid network + INNER_LENGTH = 300 + # length of final edge in route + LONG_LENGTH = 100 + # length of edges that vehicles start on + SHORT_LENGTH = 300 + # number of vehicles originating in the left, right, top, and bottom edges + N_LEFT, N_RIGHT, N_TOP, N_BOTTOM = 1, 1, 1, 1 + + # we place a sufficient number of vehicles to ensure they confirm with the + # total number specified above. We also use a "right_of_way" speed mode to + # support traffic light compliance + vehicles = VehicleParams() + vehicles.add( + veh_id="human", + acceleration_controller=(SimCarFollowingController, {}), + car_following_params=SumoCarFollowingParams( + min_gap=2.5, + max_speed=V_ENTER, + decel=7.5, # avoid collisions at emergency stops + speed_mode="right_of_way", + ), + routing_controller=(GridRouter, {}), + num_vehicles=(N_LEFT + N_RIGHT) * N_COLUMNS + (N_BOTTOM + N_TOP) * N_ROWS) + + # inflows of vehicles are place on all outer edges (listed here) + outer_edges = [] + outer_edges += ["left{}_{}".format(N_ROWS, i) for i in range(N_COLUMNS)] + outer_edges += ["right0_{}".format(i) for i in range(N_ROWS)] + outer_edges += ["bot{}_0".format(i) for i in range(N_ROWS)] + outer_edges += ["top{}_{}".format(i, N_COLUMNS) for i in range(N_ROWS)] + + # equal inflows for each edge (as dictate by the EDGE_INFLOW constant) + inflow = InFlows() + for edge in outer_edges: + inflow.add( + veh_type="human", + edge=edge, + vehs_per_hour=EDGE_INFLOW, + depart_lane="free", + depart_speed=V_ENTER) + + flow_params = dict( + # name of the experiment + exp_tag="grid_0", + + # name of the flow environment the experiment is running on + env_name=TrafficLightGridBenchmarkEnv, + + # name of the network class the experiment is running on + network=TrafficLightGridNetwork, + + # simulator that is used by the experiment + simulator='traci', + + # sumo-related parameters (see flow.core.params.SumoParams) + sim=SumoParams( + restart_instance=True, + sim_step=1, + render=render, + save_render=True, + ), + + # environment related parameters (see flow.core.params.EnvParams) + env=EnvParams( + horizon=HORIZON, + additional_params={ + "target_velocity": 50, + "switch_time": 3, + "num_observed": 2, + "discrete": False, + "tl_type": "actuated" + }, + ), + + # network-related parameters (see flow.core.params.NetParams and the + # network's documentation or ADDITIONAL_NET_PARAMS component) + net=NetParams( + inflows=inflow, + additional_params={ + "speed_limit": V_ENTER + 5, + "grid_array": { + "short_length": SHORT_LENGTH, + "inner_length": INNER_LENGTH, + "long_length": LONG_LENGTH, + "row_num": N_ROWS, + "col_num": N_COLUMNS, + "cars_left": N_LEFT, + "cars_right": N_RIGHT, + "cars_top": N_TOP, + "cars_bot": N_BOTTOM, + }, + "horizontal_lanes": 1, + "vertical_lanes": 1, + }, + ), + + # 
vehicles to be placed in the network at the start of a rollout (see + # flow.core.params.VehicleParams) + veh=vehicles, + + # parameters specifying the positioning of vehicles upon initialization/ + # reset (see flow.core.params.InitialConfig) + initial=InitialConfig( + spacing='custom', + shuffle=True, + ), + ) + return flow_params diff --git a/d4rl/d4rl/gym_bullet/__init__.py b/d4rl/d4rl/gym_bullet/__init__.py new file mode 100644 index 0000000..addea1c --- /dev/null +++ b/d4rl/d4rl/gym_bullet/__init__.py @@ -0,0 +1,25 @@ +from gym.envs.registration import register +from d4rl.gym_bullet import gym_envs +from d4rl import infos + + +for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']: + register( + id='bullet-%s-v0' % agent, + entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent, + max_episode_steps=1000, + ) + + for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay']: + env_name = 'bullet-%s-%s-v0' % (agent, dataset) + register( + id=env_name, + entry_point='d4rl.gym_bullet.gym_envs:get_%s_env' % agent, + max_episode_steps=1000, + kwargs={ + 'ref_min_score': infos.REF_MIN_SCORE[env_name], + 'ref_max_score': infos.REF_MAX_SCORE[env_name], + 'dataset_url': infos.DATASET_URLS[env_name] + } + ) + diff --git a/d4rl/d4rl/gym_bullet/gym_envs.py b/d4rl/d4rl/gym_bullet/gym_envs.py new file mode 100644 index 0000000..266f4e7 --- /dev/null +++ b/d4rl/d4rl/gym_bullet/gym_envs.py @@ -0,0 +1,37 @@ +from .. import offline_env +from pybullet_envs.gym_locomotion_envs import HopperBulletEnv, HalfCheetahBulletEnv, Walker2DBulletEnv, AntBulletEnv +from ..utils.wrappers import NormalizedBoxEnv + +class OfflineAntEnv(AntBulletEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + AntBulletEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineHopperEnv(HopperBulletEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + HopperBulletEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineHalfCheetahEnv(HalfCheetahBulletEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + HalfCheetahBulletEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineWalker2dEnv(Walker2DBulletEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + Walker2DBulletEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + + +def get_ant_env(**kwargs): + return NormalizedBoxEnv(OfflineAntEnv(**kwargs)) + +def get_halfcheetah_env(**kwargs): + return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs)) + +def get_hopper_env(**kwargs): + return NormalizedBoxEnv(OfflineHopperEnv(**kwargs)) + +def get_walker2d_env(**kwargs): + return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs)) + diff --git a/d4rl/d4rl/gym_minigrid/__init__.py b/d4rl/d4rl/gym_minigrid/__init__.py new file mode 100644 index 0000000..8c2ee05 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/__init__.py @@ -0,0 +1,23 @@ +from gym.envs.registration import register + +register( + id='minigrid-fourrooms-v0', + entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv', + max_episode_steps=50, + kwargs={ + 'ref_min_score': 0.01442, + 'ref_max_score': 2.89685, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5' + } +) + +register( + id='minigrid-fourrooms-random-v0', + entry_point='d4rl.gym_minigrid.envs.fourrooms:FourRoomsEnv', + max_episode_steps=50, + kwargs={ + 'ref_min_score': 0.01442, + 'ref_max_score': 2.89685, + 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5' + } +) diff --git a/d4rl/d4rl/gym_minigrid/envs/__init__.py b/d4rl/d4rl/gym_minigrid/envs/__init__.py new file mode 100644 index 0000000..6fc1721 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/envs/__init__.py @@ -0,0 +1,2 @@ +from d4rl.gym_minigrid.envs.fourrooms import * +from d4rl.gym_minigrid.envs.empty import * diff --git a/d4rl/d4rl/gym_minigrid/envs/empty.py b/d4rl/d4rl/gym_minigrid/envs/empty.py new file mode 100644 index 0000000..0a371df --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/envs/empty.py @@ -0,0 +1,92 @@ +from d4rl.gym_minigrid.minigrid import * +from d4rl.gym_minigrid.register import register + +class EmptyEnv(MiniGridEnv): + """ + Empty grid environment, no obstacles, sparse reward + """ + + def __init__( + self, + size=8, + agent_start_pos=(1,1), + agent_start_dir=0, + ): + self.agent_start_pos = agent_start_pos + self.agent_start_dir = agent_start_dir + + super().__init__( + grid_size=size, + max_steps=4*size*size, + # Set this to True for maximum speed + see_through_walls=True + ) + + def _gen_grid(self, width, height): + # Create an empty grid + self.grid = Grid(width, height) + + # Generate the surrounding walls + self.grid.wall_rect(0, 0, width, height) + + # Place a goal square in the bottom-right corner + self.put_obj(Goal(), width - 2, height - 2) + + # Place the agent + if self.agent_start_pos is not None: + self.agent_pos = self.agent_start_pos + self.agent_dir = self.agent_start_dir + else: + self.place_agent() + + self.mission = "get to the green goal square" + +class EmptyEnv5x5(EmptyEnv): + def __init__(self): + super().__init__(size=5) + +class EmptyRandomEnv5x5(EmptyEnv): + def __init__(self): + super().__init__(size=5, agent_start_pos=None) + +class EmptyEnv6x6(EmptyEnv): + def __init__(self): + super().__init__(size=6) + +class EmptyRandomEnv6x6(EmptyEnv): + def __init__(self): + super().__init__(size=6, agent_start_pos=None) + +class EmptyEnv16x16(EmptyEnv): + def __init__(self): + super().__init__(size=16) + +register( + id='MiniGrid-Empty-5x5-v0', + entry_point='gym_minigrid.envs:EmptyEnv5x5' +) + +register( + id='MiniGrid-Empty-Random-5x5-v0', + entry_point='gym_minigrid.envs:EmptyRandomEnv5x5' +) + +register( + id='MiniGrid-Empty-6x6-v0', + entry_point='gym_minigrid.envs:EmptyEnv6x6' +) + +register( + id='MiniGrid-Empty-Random-6x6-v0', + entry_point='gym_minigrid.envs:EmptyRandomEnv6x6' +) + +register( + id='MiniGrid-Empty-8x8-v0', + entry_point='gym_minigrid.envs:EmptyEnv' +) + +register( + id='MiniGrid-Empty-16x16-v0', + entry_point='gym_minigrid.envs:EmptyEnv16x16' +) diff --git a/d4rl/d4rl/gym_minigrid/envs/fourrooms.py b/d4rl/d4rl/gym_minigrid/envs/fourrooms.py new file mode 100644 index 0000000..7956fdd --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/envs/fourrooms.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from d4rl.gym_minigrid.minigrid import * +from d4rl.gym_minigrid.register import register + + +class FourRoomsEnv(MiniGridEnv): + """ + Classic 4 rooms gridworld environment. + Can specify agent and goal position, if not it set at random. 
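+    When goal_pos is omitted the goal defaults to (12, 12); the grid is 19x19 and
+    episodes last at most 100 steps (see __init__ below).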
+ """ + + def __init__(self, agent_pos=None, goal_pos=None, **kwargs): + self._agent_default_pos = agent_pos + if goal_pos is None: + goal_pos = (12, 12) + self._goal_default_pos = goal_pos + super().__init__(grid_size=19, max_steps=100, **kwargs) + + def get_target(self): + return self._goal_default_pos + + def _gen_grid(self, width, height): + # Create the grid + self.grid = Grid(width, height) + + # Generate the surrounding walls + self.grid.horz_wall(0, 0) + self.grid.horz_wall(0, height - 1) + self.grid.vert_wall(0, 0) + self.grid.vert_wall(width - 1, 0) + + room_w = width // 2 + room_h = height // 2 + + # For each row of rooms + for j in range(0, 2): + + # For each column + for i in range(0, 2): + xL = i * room_w + yT = j * room_h + xR = xL + room_w + yB = yT + room_h + + # Bottom wall and door + if i + 1 < 2: + self.grid.vert_wall(xR, yT, room_h) + pos = (xR, self._rand_int(yT + 1, yB)) + self.grid.set(*pos, None) + + # Bottom wall and door + if j + 1 < 2: + self.grid.horz_wall(xL, yB, room_w) + pos = (self._rand_int(xL + 1, xR), yB) + self.grid.set(*pos, None) + + # Randomize the player start position and orientation + if self._agent_default_pos is not None: + self.agent_pos = self._agent_default_pos + self.grid.set(*self._agent_default_pos, None) + self.agent_dir = self._rand_int(0, 4) # assuming random start direction + else: + self.place_agent() + + if self._goal_default_pos is not None: + goal = Goal() + self.put_obj(goal, *self._goal_default_pos) + goal.init_pos, goal.cur_pos = self._goal_default_pos + else: + self.place_obj(Goal()) + + self.mission = 'Reach the goal' + + def step(self, action): + obs, reward, done, info = MiniGridEnv.step(self, action) + return obs, reward, done, info + +register( + id='MiniGrid-FourRooms-v0', + entry_point='gym_minigrid.envs:FourRoomsEnv' +) diff --git a/d4rl/d4rl/gym_minigrid/fourroom_controller.py b/d4rl/d4rl/gym_minigrid/fourroom_controller.py new file mode 100644 index 0000000..97db3ab --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/fourroom_controller.py @@ -0,0 +1,84 @@ +import numpy as np +import random + +from d4rl.pointmaze import q_iteration +from d4rl.pointmaze.gridcraft import grid_env +from d4rl.pointmaze.gridcraft import grid_spec + +MAZE = \ +"###################\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOOOOOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"####O#########O####\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"#OOOOOOOOOOOOOOOOO#\\"+\ +"#OOOOOOOO#OOOOOOOO#\\"+\ +"###################\\" + + +# NLUDR -> RDLU +TRANSLATE_DIRECTION = { + 0: None, + 1: 3,#3, + 2: 1,#1, + 3: 2,#2, + 4: 0,#0, +} + +RIGHT = 1 +LEFT = 0 +FORWARD = 2 + +class FourRoomController(object): + def __init__(self): + self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE)) + self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY))) + + def sample_target(self): + return random.choice(self.reset_locations) + + def set_target(self, target): + self.target = target + self.env.gs[target] = grid_spec.REWARD + self.q_values = q_iteration.q_iteration(env=self.env, num_itrs=32, discount=0.99) + self.env.gs[target] = grid_spec.EMPTY + + def get_action(self, pos, orientation): + if tuple(pos) == tuple(self.target): + done = True + else: + done = False + env_pos_idx = self.env.gs.xy_to_idx(pos) 
+ qvalues = self.q_values[env_pos_idx] + direction = TRANSLATE_DIRECTION[np.argmax(qvalues)] + #tgt_pos, _ = self.env.step_stateless(env_pos_idx, np.argmax(qvalues)) + #tgt_pos = self.env.gs.idx_to_xy(tgt_pos) + #print('\tcmd_dir:', direction, np.argmax(qvalues), qvalues, tgt_pos) + #infos = {} + #infos['tgt_pos'] = tgt_pos + if orientation == direction or direction == None: + return FORWARD, done + else: + return get_turn(orientation, direction), done + +#RDLU +TURN_DIRS = [ + [None, RIGHT, RIGHT, LEFT], #R + [LEFT, None, RIGHT, RIGHT], #D + [RIGHT, LEFT, None, RIGHT], #L + [RIGHT, RIGHT, LEFT, None], #U +] + +def get_turn(ori, tgt_ori): + return TURN_DIRS[ori][tgt_ori] diff --git a/d4rl/d4rl/gym_minigrid/minigrid.py b/d4rl/d4rl/gym_minigrid/minigrid.py new file mode 100644 index 0000000..b100fd0 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/minigrid.py @@ -0,0 +1,1289 @@ +import math +import gym +from enum import IntEnum +import numpy as np +from gym import error, spaces, utils +from gym.utils import seeding +from d4rl.gym_minigrid.rendering import * +from d4rl import offline_env + +# Size in pixels of a tile in the full-scale human view +TILE_PIXELS = 32 + +# Map of color names to RGB values +COLORS = { + 'red' : np.array([255, 0, 0]), + 'green' : np.array([0, 255, 0]), + 'blue' : np.array([0, 0, 255]), + 'purple': np.array([112, 39, 195]), + 'yellow': np.array([255, 255, 0]), + 'grey' : np.array([100, 100, 100]) +} + +COLOR_NAMES = sorted(list(COLORS.keys())) + +# Used to map colors to integers +COLOR_TO_IDX = { + 'red' : 0, + 'green' : 1, + 'blue' : 2, + 'purple': 3, + 'yellow': 4, + 'grey' : 5 +} + +IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys())) + +# Map of object type to integers +OBJECT_TO_IDX = { + 'unseen' : 0, + 'empty' : 1, + 'wall' : 2, + 'floor' : 3, + 'door' : 4, + 'key' : 5, + 'ball' : 6, + 'box' : 7, + 'goal' : 8, + 'lava' : 9, + 'agent' : 10, +} + +IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys())) + +# Map of state names to integers +STATE_TO_IDX = { + 'open' : 0, + 'closed': 1, + 'locked': 2, +} + +# Map of agent direction indices to vectors +DIR_TO_VEC = [ + # Pointing right (positive X) + np.array((1, 0)), + # Down (positive Y) + np.array((0, 1)), + # Pointing left (negative X) + np.array((-1, 0)), + # Up (negative Y) + np.array((0, -1)), +] + +class WorldObj: + """ + Base class for grid world objects + """ + + def __init__(self, type, color): + assert type in OBJECT_TO_IDX, type + assert color in COLOR_TO_IDX, color + self.type = type + self.color = color + self.contains = None + + # Initial position of the object + self.init_pos = None + + # Current position of the object + self.cur_pos = None + + def can_overlap(self): + """Can the agent overlap with this?""" + return False + + def can_pickup(self): + """Can the agent pick this up?""" + return False + + def can_contain(self): + """Can this contain another object?""" + return False + + def see_behind(self): + """Can the agent see behind this object?""" + return True + + def toggle(self, env, pos): + """Method to trigger/toggle an action this object performs""" + return False + + def encode(self): + """Encode the a description of this object as a 3-tuple of integers""" + return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0) + + @staticmethod + def decode(type_idx, color_idx, state): + """Create an object from a 3-tuple state description""" + + obj_type = IDX_TO_OBJECT[type_idx] + color = IDX_TO_COLOR[color_idx] + + if obj_type == 'empty' or obj_type == 'unseen': + 
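+            # nothing to reconstruct for empty or unobserved cells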
return None + + # State, 0: open, 1: closed, 2: locked + is_open = state == 0 + is_locked = state == 2 + + if obj_type == 'wall': + v = Wall(color) + elif obj_type == 'floor': + v = Floor(color) + elif obj_type == 'ball': + v = Ball(color) + elif obj_type == 'key': + v = Key(color) + elif obj_type == 'box': + v = Box(color) + elif obj_type == 'door': + v = Door(color, is_open, is_locked) + elif obj_type == 'goal': + v = Goal() + elif obj_type == 'lava': + v = Lava() + else: + assert False, "unknown object type in decode '%s'" % objType + + return v + + def render(self, r): + """Draw this object with the given renderer""" + raise NotImplementedError + +class Goal(WorldObj): + def __init__(self): + super().__init__('goal', 'green') + + def can_overlap(self): + return True + + def render(self, img): + fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color]) + +class Floor(WorldObj): + """ + Colored floor tile the agent can walk over + """ + + def __init__(self, color='blue'): + super().__init__('floor', color) + + def can_overlap(self): + return True + + def render(self, r): + # Give the floor a pale color + c = COLORS[self.color] + r.setLineColor(100, 100, 100, 0) + r.setColor(*c/2) + r.drawPolygon([ + (1 , TILE_PIXELS), + (TILE_PIXELS, TILE_PIXELS), + (TILE_PIXELS, 1), + (1 , 1) + ]) + +class Lava(WorldObj): + def __init__(self): + super().__init__('lava', 'red') + + def can_overlap(self): + return True + + def render(self, img): + c = (255, 128, 0) + + # Background color + fill_coords(img, point_in_rect(0, 1, 0, 1), c) + + # Little waves + for i in range(3): + ylo = 0.3 + 0.2 * i + yhi = 0.4 + 0.2 * i + fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0,0,0)) + fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0,0,0)) + fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0,0,0)) + fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0,0,0)) + +class Wall(WorldObj): + def __init__(self, color='grey'): + super().__init__('wall', color) + + def see_behind(self): + return False + + def render(self, img): + fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color]) + +class Door(WorldObj): + def __init__(self, color, is_open=False, is_locked=False): + super().__init__('door', color) + self.is_open = is_open + self.is_locked = is_locked + + def can_overlap(self): + """The agent can only walk over this cell when the door is open""" + return self.is_open + + def see_behind(self): + return self.is_open + + def toggle(self, env, pos): + # If the player has the right key to open the door + if self.is_locked: + if isinstance(env.carrying, Key) and env.carrying.color == self.color: + self.is_locked = False + self.is_open = True + return True + return False + + self.is_open = not self.is_open + return True + + def encode(self): + """Encode the a description of this object as a 3-tuple of integers""" + + # State, 0: open, 1: closed, 2: locked + if self.is_open: + state = 0 + elif self.is_locked: + state = 2 + elif not self.is_open: + state = 1 + + return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state) + + def render(self, img): + c = COLORS[self.color] + + if self.is_open: + fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c) + fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0,0,0)) + return + + # Door frame and door + if self.is_locked: + fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c) + fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c)) + + # Draw key slot + fill_coords(img, 
point_in_rect(0.52, 0.75, 0.50, 0.56), c) + else: + fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c) + fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0,0,0)) + fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c) + fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0,0,0)) + + # Draw door handle + fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c) + +class Key(WorldObj): + def __init__(self, color='blue'): + super(Key, self).__init__('key', color) + + def can_pickup(self): + return True + + def render(self, img): + c = COLORS[self.color] + + # Vertical quad + fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c) + + # Teeth + fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c) + fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c) + + # Ring + fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c) + fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0,0,0)) + +class Ball(WorldObj): + def __init__(self, color='blue'): + super(Ball, self).__init__('ball', color) + + def can_pickup(self): + return True + + def render(self, img): + fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color]) + +class Box(WorldObj): + def __init__(self, color, contains=None): + super(Box, self).__init__('box', color) + self.contains = contains + + def can_pickup(self): + return True + + def render(self, img): + c = COLORS[self.color] + + # Outline + fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c) + fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0,0,0)) + + # Horizontal slit + fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c) + + def toggle(self, env, pos): + # Replace the box by its contents + env.grid.set(*pos, self.contains) + return True + +class Grid: + """ + Represent a grid and operations on it + """ + + # Static cache of pre-renderer tiles + tile_cache = {} + + def __init__(self, width, height): + assert width >= 3 + assert height >= 3 + + self.width = width + self.height = height + + self.grid = [None] * width * height + + def __contains__(self, key): + if isinstance(key, WorldObj): + for e in self.grid: + if e is key: + return True + elif isinstance(key, tuple): + for e in self.grid: + if e is None: + continue + if (e.color, e.type) == key: + return True + if key[0] is None and key[1] == e.type: + return True + return False + + def __eq__(self, other): + grid1 = self.encode() + grid2 = other.encode() + return np.array_equal(grid2, grid1) + + def __ne__(self, other): + return not self == other + + def copy(self): + from copy import deepcopy + return deepcopy(self) + + def set(self, i, j, v): + assert i >= 0 and i < self.width + assert j >= 0 and j < self.height + self.grid[j * self.width + i] = v + + def get(self, i, j): + assert i >= 0 and i < self.width + assert j >= 0 and j < self.height + return self.grid[j * self.width + i] + + def horz_wall(self, x, y, length=None, obj_type=Wall): + if length is None: + length = self.width - x + for i in range(0, length): + self.set(x + i, y, obj_type()) + + def vert_wall(self, x, y, length=None, obj_type=Wall): + if length is None: + length = self.height - y + for j in range(0, length): + self.set(x, y + j, obj_type()) + + def wall_rect(self, x, y, w, h): + self.horz_wall(x, y, w) + self.horz_wall(x, y+h-1, w) + self.vert_wall(x, y, h) + self.vert_wall(x+w-1, y, h) + + def rotate_left(self): + """ + Rotate the grid to the left (counter-clockwise) + """ + + grid = Grid(self.height, self.width) + + for i in range(self.width): + for j in 
range(self.height): + v = self.get(i, j) + grid.set(j, grid.height - 1 - i, v) + + return grid + + def slice(self, topX, topY, width, height): + """ + Get a subset of the grid + """ + + grid = Grid(width, height) + + for j in range(0, height): + for i in range(0, width): + x = topX + i + y = topY + j + + if x >= 0 and x < self.width and \ + y >= 0 and y < self.height: + v = self.get(x, y) + else: + v = Wall() + + grid.set(i, j, v) + + return grid + + @classmethod + def render_tile( + cls, + obj, + agent_dir=None, + highlight=False, + tile_size=TILE_PIXELS, + subdivs=3 + ): + """ + Render a tile and cache the result + """ + + # Hash map lookup key for the cache + key = (agent_dir, highlight, tile_size) + key = obj.encode() + key if obj else key + + if key in cls.tile_cache: + return cls.tile_cache[key] + + img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8) + + # Draw the grid lines (top and left edges) + fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100)) + fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100)) + + if obj != None: + obj.render(img) + + # Overlay the agent on top + if agent_dir is not None: + tri_fn = point_in_triangle( + (0.12, 0.19), + (0.87, 0.50), + (0.12, 0.81), + ) + + # Rotate the agent based on its direction + tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir) + fill_coords(img, tri_fn, (255, 0, 0)) + + # Highlight the cell if needed + if highlight: + highlight_img(img) + + # Downsample the image to perform supersampling/anti-aliasing + img = downsample(img, subdivs) + + # Cache the rendered tile + cls.tile_cache[key] = img + + return img + + def render( + self, + tile_size, + agent_pos=None, + agent_dir=None, + highlight_mask=None + ): + """ + Render this grid at a given scale + :param r: target renderer object + :param tile_size: tile size in pixels + """ + + if highlight_mask is None: + highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool) + + # Compute the total grid size + width_px = self.width * tile_size + height_px = self.height * tile_size + + img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) + + # Render the grid + for j in range(0, self.height): + for i in range(0, self.width): + cell = self.get(i, j) + + agent_here = np.array_equal(agent_pos, (i, j)) + tile_img = Grid.render_tile( + cell, + agent_dir=agent_dir if agent_here else None, + highlight=highlight_mask[i, j], + tile_size=tile_size + ) + + ymin = j * tile_size + ymax = (j+1) * tile_size + xmin = i * tile_size + xmax = (i+1) * tile_size + img[ymin:ymax, xmin:xmax, :] = tile_img + + return img + + def encode(self, vis_mask=None): + """ + Produce a compact numpy encoding of the grid + """ + + if vis_mask is None: + vis_mask = np.ones((self.width, self.height), dtype=bool) + + array = np.zeros((self.width, self.height, 3), dtype='uint8') + + for i in range(self.width): + for j in range(self.height): + if vis_mask[i, j]: + v = self.get(i, j) + + if v is None: + array[i, j, 0] = OBJECT_TO_IDX['empty'] + array[i, j, 1] = 0 + array[i, j, 2] = 0 + + else: + array[i, j, :] = v.encode() + + return array + + @staticmethod + def decode(array): + """ + Decode an array grid encoding back into a grid + """ + + width, height, channels = array.shape + assert channels == 3 + + vis_mask = np.ones(shape=(width, height), dtype=np.bool) + + grid = Grid(width, height) + for i in range(width): + for j in range(height): + type_idx, color_idx, state = array[i, j] + v = WorldObj.decode(type_idx, color_idx, state) + 
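+                # rebuild the stored object for this cell and mark it as observed
+                # unless it was encoded as 'unseen'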
grid.set(i, j, v) + vis_mask[i, j] = (type_idx != OBJECT_TO_IDX['unseen']) + + return grid, vis_mask + + def process_vis(grid, agent_pos): + mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool) + + mask[agent_pos[0], agent_pos[1]] = True + + for j in reversed(range(0, grid.height)): + for i in range(0, grid.width-1): + if not mask[i, j]: + continue + + cell = grid.get(i, j) + if cell and not cell.see_behind(): + continue + + mask[i+1, j] = True + if j > 0: + mask[i+1, j-1] = True + mask[i, j-1] = True + + for i in reversed(range(1, grid.width)): + if not mask[i, j]: + continue + + cell = grid.get(i, j) + if cell and not cell.see_behind(): + continue + + mask[i-1, j] = True + if j > 0: + mask[i-1, j-1] = True + mask[i, j-1] = True + + for j in range(0, grid.height): + for i in range(0, grid.width): + if not mask[i, j]: + grid.set(i, j, None) + + return mask + +class MiniGridEnv(offline_env.OfflineEnv): + """ + 2D grid world game environment + """ + + metadata = { + 'render.modes': ['human', 'rgb_array'], + 'video.frames_per_second' : 10 + } + + # Enumeration of possible actions + class Actions(IntEnum): + # Turn left, turn right, move forward + left = 0 + right = 1 + forward = 2 + + # Pick up an object + pickup = 3 + # Drop an object + drop = 4 + # Toggle/activate an object + toggle = 5 + + # Done completing task + done = 6 + + def __init__( + self, + grid_size=None, + width=None, + height=None, + max_steps=100, + see_through_walls=False, + seed=1337, + agent_view_size=7, + **kwargs + ): + offline_env.OfflineEnv.__init__(self, **kwargs) + # Can't set both grid_size and width/height + if grid_size: + assert width == None and height == None + width = grid_size + height = grid_size + + # Action enumeration for this environment + self.actions = MiniGridEnv.Actions + + # Actions are discrete integer values + self.action_space = spaces.Discrete(len(self.actions)) + + # Number of cells (width and height) in the agent view + self.agent_view_size = agent_view_size + + # Observations are dictionaries containing an + # encoding of the grid and a textual 'mission' string + self.observation_space = spaces.Box( + low=0, + high=255, + shape=(self.agent_view_size, self.agent_view_size, 3), + dtype='uint8' + ) + self.observation_space = spaces.Dict({ + 'image': self.observation_space + }) + + # Range of possible rewards + self.reward_range = (0, 1) + + # Window to use for human rendering mode + self.window = None + + # Environment configuration + self.width = width + self.height = height + self.max_steps = max_steps + self.see_through_walls = see_through_walls + + # Current position and direction of the agent + self.agent_pos = None + self.agent_dir = None + + # Initialize the RNG + self.seed(seed=seed) + + # Initialize the state + self.reset() + + def reset(self): + # Current position and direction of the agent + self.agent_pos = None + self.agent_dir = None + + # Generate a new random grid at the start of each episode + # To keep the same grid for each episode, call env.seed() with + # the same seed before calling env.reset() + self._gen_grid(self.width, self.height) + + # These fields should be defined by _gen_grid + assert self.agent_pos is not None + assert self.agent_dir is not None + + # Check that the agent doesn't overlap with an object + start_cell = self.grid.get(*self.agent_pos) + assert start_cell is None or start_cell.can_overlap() + + # Item picked up, being carried, initially nothing + self.carrying = None + + # Step count since episode start + self.step_count = 0 + + # Return 
first observation + obs = self.gen_obs() + return obs + + def seed(self, seed=1337): + # Seed the random number generator + self.np_random, _ = seeding.np_random(seed) + return [seed] + + @property + def steps_remaining(self): + return self.max_steps - self.step_count + + def __str__(self): + """ + Produce a pretty string of the environment's grid along with the agent. + A grid cell is represented by 2-character string, the first one for + the object and the second one for the color. + """ + + # Map of object types to short string + OBJECT_TO_STR = { + 'wall' : 'W', + 'floor' : 'F', + 'door' : 'D', + 'key' : 'K', + 'ball' : 'A', + 'box' : 'B', + 'goal' : 'G', + 'lava' : 'V', + } + + # Short string for opened door + OPENDED_DOOR_IDS = '_' + + # Map agent's direction to short string + AGENT_DIR_TO_STR = { + 0: '>', + 1: 'V', + 2: '<', + 3: '^' + } + + str = '' + + for j in range(self.grid.height): + + for i in range(self.grid.width): + if i == self.agent_pos[0] and j == self.agent_pos[1]: + str += 2 * AGENT_DIR_TO_STR[self.agent_dir] + continue + + c = self.grid.get(i, j) + + if c == None: + str += ' ' + continue + + if c.type == 'door': + if c.is_open: + str += '__' + elif c.is_locked: + str += 'L' + c.color[0].upper() + else: + str += 'D' + c.color[0].upper() + continue + + str += OBJECT_TO_STR[c.type] + c.color[0].upper() + + if j < self.grid.height - 1: + str += '\n' + + return str + + def _gen_grid(self, width, height): + assert False, "_gen_grid needs to be implemented by each environment" + + def _reward(self): + """ + Compute the reward to be given upon success + """ + + return 1 - 0.9 * (self.step_count / self.max_steps) + + def _rand_int(self, low, high): + """ + Generate random integer in [low,high[ + """ + + return self.np_random.randint(low, high) + + def _rand_float(self, low, high): + """ + Generate random float in [low,high[ + """ + + return self.np_random.uniform(low, high) + + def _rand_bool(self): + """ + Generate random boolean value + """ + + return (self.np_random.randint(0, 2) == 0) + + def _rand_elem(self, iterable): + """ + Pick a random element in a list + """ + + lst = list(iterable) + idx = self._rand_int(0, len(lst)) + return lst[idx] + + def _rand_subset(self, iterable, num_elems): + """ + Sample a random subset of distinct elements of a list + """ + + lst = list(iterable) + assert num_elems <= len(lst) + + out = [] + + while len(out) < num_elems: + elem = self._rand_elem(lst) + lst.remove(elem) + out.append(elem) + + return out + + def _rand_color(self): + """ + Generate a random color name (string) + """ + + return self._rand_elem(COLOR_NAMES) + + def _rand_pos(self, xLow, xHigh, yLow, yHigh): + """ + Generate a random (x,y) position tuple + """ + + return ( + self.np_random.randint(xLow, xHigh), + self.np_random.randint(yLow, yHigh) + ) + + def place_obj(self, + obj, + top=None, + size=None, + reject_fn=None, + max_tries=math.inf + ): + """ + Place an object at an empty position in the grid + + :param top: top-left position of the rectangle where to place + :param size: size of the rectangle where to place + :param reject_fn: function to filter out potential positions + """ + + if top is None: + top = (0, 0) + else: + top = (max(top[0], 0), max(top[1], 0)) + + if size is None: + size = (self.grid.width, self.grid.height) + + num_tries = 0 + + while True: + # This is to handle with rare cases where rejection sampling + # gets stuck in an infinite loop + if num_tries > max_tries: + raise RecursionError('rejection sampling failed in place_obj') + + num_tries += 
1 + + pos = np.array(( + self._rand_int(top[0], min(top[0] + size[0], self.grid.width)), + self._rand_int(top[1], min(top[1] + size[1], self.grid.height)) + )) + + # Don't place the object on top of another object + if self.grid.get(*pos) != None: + continue + + # Don't place the object where the agent is + if np.array_equal(pos, self.agent_pos): + continue + + # Check if there is a filtering criterion + if reject_fn and reject_fn(self, pos): + continue + + break + + self.grid.set(*pos, obj) + + if obj is not None: + obj.init_pos = pos + obj.cur_pos = pos + + return pos + + def put_obj(self, obj, i, j): + """ + Put an object at a specific position in the grid + """ + + self.grid.set(i, j, obj) + obj.init_pos = (i, j) + obj.cur_pos = (i, j) + + def place_agent( + self, + top=None, + size=None, + rand_dir=True, + max_tries=math.inf + ): + """ + Set the agent's starting point at an empty position in the grid + """ + + self.agent_pos = None + pos = self.place_obj(None, top, size, max_tries=max_tries) + self.agent_pos = pos + + if rand_dir: + self.agent_dir = self._rand_int(0, 4) + + return pos + + @property + def dir_vec(self): + """ + Get the direction vector for the agent, pointing in the direction + of forward movement. + """ + + assert self.agent_dir >= 0 and self.agent_dir < 4 + return DIR_TO_VEC[self.agent_dir] + + @property + def right_vec(self): + """ + Get the vector pointing to the right of the agent. + """ + + dx, dy = self.dir_vec + return np.array((-dy, dx)) + + @property + def front_pos(self): + """ + Get the position of the cell that is right in front of the agent + """ + + return self.agent_pos + self.dir_vec + + def get_view_coords(self, i, j): + """ + Translate and rotate absolute grid coordinates (i, j) into the + agent's partially observable view (sub-grid). Note that the resulting + coordinates may be negative or outside of the agent's view size. 
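+        For example, the agent's own cell maps to (agent_view_size // 2,
+        agent_view_size - 1) and the cell directly ahead of it maps to
+        (agent_view_size // 2, agent_view_size - 2).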
+ """ + + ax, ay = self.agent_pos + dx, dy = self.dir_vec + rx, ry = self.right_vec + + # Compute the absolute coordinates of the top-left view corner + sz = self.agent_view_size + hs = self.agent_view_size // 2 + tx = ax + (dx * (sz-1)) - (rx * hs) + ty = ay + (dy * (sz-1)) - (ry * hs) + + lx = i - tx + ly = j - ty + + # Project the coordinates of the object relative to the top-left + # corner onto the agent's own coordinate system + vx = (rx*lx + ry*ly) + vy = -(dx*lx + dy*ly) + + return vx, vy + + def get_view_exts(self): + """ + Get the extents of the square set of tiles visible to the agent + Note: the bottom extent indices are not included in the set + """ + + # Facing right + if self.agent_dir == 0: + topX = self.agent_pos[0] + topY = self.agent_pos[1] - self.agent_view_size // 2 + # Facing down + elif self.agent_dir == 1: + topX = self.agent_pos[0] - self.agent_view_size // 2 + topY = self.agent_pos[1] + # Facing left + elif self.agent_dir == 2: + topX = self.agent_pos[0] - self.agent_view_size + 1 + topY = self.agent_pos[1] - self.agent_view_size // 2 + # Facing up + elif self.agent_dir == 3: + topX = self.agent_pos[0] - self.agent_view_size // 2 + topY = self.agent_pos[1] - self.agent_view_size + 1 + else: + assert False, "invalid agent direction" + + botX = topX + self.agent_view_size + botY = topY + self.agent_view_size + + return (topX, topY, botX, botY) + + def relative_coords(self, x, y): + """ + Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates + """ + + vx, vy = self.get_view_coords(x, y) + + if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size: + return None + + return vx, vy + + def in_view(self, x, y): + """ + check if a grid position is visible to the agent + """ + + return self.relative_coords(x, y) is not None + + def agent_sees(self, x, y): + """ + Check if a non-empty grid position is visible to the agent + """ + + coordinates = self.relative_coords(x, y) + if coordinates is None: + return False + vx, vy = coordinates + + obs = self.gen_obs() + obs_grid, _ = Grid.decode(obs['image']) + obs_cell = obs_grid.get(vx, vy) + world_cell = self.grid.get(x, y) + + return obs_cell is not None and obs_cell.type == world_cell.type + + def step(self, action): + self.step_count += 1 + + reward = 0 + done = False + + # Get the position in front of the agent + fwd_pos = self.front_pos + + # Get the contents of the cell in front of the agent + fwd_cell = self.grid.get(*fwd_pos) + + # Rotate left + if action == self.actions.left: + self.agent_dir -= 1 + if self.agent_dir < 0: + self.agent_dir += 4 + + # Rotate right + elif action == self.actions.right: + self.agent_dir = (self.agent_dir + 1) % 4 + + # Move forward + elif action == self.actions.forward: + if fwd_cell == None or fwd_cell.can_overlap(): + self.agent_pos = fwd_pos + if fwd_cell != None and fwd_cell.type == 'goal': + done = True + reward = self._reward() + if fwd_cell != None and fwd_cell.type == 'lava': + done = True + + # Pick up an object + elif action == self.actions.pickup: + if fwd_cell and fwd_cell.can_pickup(): + if self.carrying is None: + self.carrying = fwd_cell + self.carrying.cur_pos = np.array([-1, -1]) + self.grid.set(*fwd_pos, None) + + # Drop an object + elif action == self.actions.drop: + if not fwd_cell and self.carrying: + self.grid.set(*fwd_pos, self.carrying) + self.carrying.cur_pos = fwd_pos + self.carrying = None + + # Toggle/activate an object + elif action == self.actions.toggle: + if fwd_cell: + 
fwd_cell.toggle(self, fwd_pos) + + # Done action (not used by default) + elif action == self.actions.done: + pass + + else: + assert False, "unknown action" + + if self.step_count >= self.max_steps: + done = True + + obs = self.gen_obs() + + return obs, reward, done, {} + + def gen_obs_grid(self): + """ + Generate the sub-grid observed by the agent. + This method also outputs a visibility mask telling us which grid + cells the agent can actually see. + """ + + topX, topY, botX, botY = self.get_view_exts() + + grid = self.grid.slice(topX, topY, self.agent_view_size, self.agent_view_size) + + for i in range(self.agent_dir + 1): + grid = grid.rotate_left() + + # Process occluders and visibility + # Note that this incurs some performance cost + if not self.see_through_walls: + vis_mask = grid.process_vis(agent_pos=(self.agent_view_size // 2 , self.agent_view_size - 1)) + else: + vis_mask = np.ones(shape=(grid.width, grid.height), dtype=np.bool) + + # Make it so the agent sees what it's carrying + # We do this by placing the carried object at the agent's position + # in the agent's partially observable view + agent_pos = grid.width // 2, grid.height - 1 + if self.carrying: + grid.set(*agent_pos, self.carrying) + else: + grid.set(*agent_pos, None) + + return grid, vis_mask + + def gen_obs(self): + """ + Generate the agent's view (partially observable, low-resolution encoding) + """ + + grid, vis_mask = self.gen_obs_grid() + + # Encode the partially observable view into a numpy array + image = grid.encode(vis_mask) + + assert hasattr(self, 'mission'), "environments must define a textual mission string" + + # Observations are dictionaries containing: + # - an image (partially observable view of the environment) + # - the agent's direction/orientation (acting as a compass) + # - a textual mission string (instructions for the agent) + obs = { + 'image': image, + 'direction': self.agent_dir, + 'mission': self.mission + } + + return obs + + def get_obs_render(self, obs, tile_size=TILE_PIXELS//2): + """ + Render an agent observation for visualization + """ + + grid, vis_mask = Grid.decode(obs) + + # Render the whole grid + img = grid.render( + tile_size, + agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1), + agent_dir=3, + highlight_mask=vis_mask + ) + + return img + + def render(self, mode='human', close=False, highlight=True, tile_size=TILE_PIXELS): + """ + Render the whole-grid human view + """ + + if close: + if self.window: + self.window.close() + return + + if mode == 'human' and not self.window: + import d4rl.gym_minigrid.window + self.window = d4rl.gym_minigrid.window.Window('gym_minigrid') + self.window.show(block=False) + + # Compute which cells are visible to the agent + _, vis_mask = self.gen_obs_grid() + + # Compute the world coordinates of the bottom-left corner + # of the agent's view area + f_vec = self.dir_vec + r_vec = self.right_vec + top_left = self.agent_pos + f_vec * (self.agent_view_size-1) - r_vec * (self.agent_view_size // 2) + + # Mask of which cells to highlight + highlight_mask = np.zeros(shape=(self.width, self.height), dtype=np.bool) + + # For each cell in the visibility mask + for vis_j in range(0, self.agent_view_size): + for vis_i in range(0, self.agent_view_size): + # If this cell is not visible, don't highlight it + if not vis_mask[vis_i, vis_j]: + continue + + # Compute the world coordinates of this cell + abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i) + + if abs_i < 0 or abs_i >= self.width: + continue + if abs_j < 0 or abs_j >= 
self.height: + continue + + # Mark this cell to be highlighted + highlight_mask[abs_i, abs_j] = True + + # Render the whole grid + img = self.grid.render( + tile_size, + self.agent_pos, + self.agent_dir, + highlight_mask=highlight_mask if highlight else None + ) + + if mode == 'human': + self.window.show_img(img) + self.window.set_caption(self.mission) + + return img diff --git a/d4rl/d4rl/gym_minigrid/register.py b/d4rl/d4rl/gym_minigrid/register.py new file mode 100644 index 0000000..cd56774 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/register.py @@ -0,0 +1,21 @@ +from gym.envs.registration import register as gym_register + +env_list = [] + +def register( + id, + entry_point, + reward_threshold=0.95 +): + assert id.startswith("MiniGrid-") + assert id not in env_list + + # Register the environment with OpenAI gym + gym_register( + id=id, + entry_point=entry_point, + reward_threshold=reward_threshold + ) + + # Add the environment to the set + env_list.append(id) diff --git a/d4rl/d4rl/gym_minigrid/rendering.py b/d4rl/d4rl/gym_minigrid/rendering.py new file mode 100644 index 0000000..dd11074 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/rendering.py @@ -0,0 +1,118 @@ +import math +import numpy as np + +def downsample(img, factor): + """ + Downsample an image along both dimensions by some factor + """ + + assert img.shape[0] % factor == 0 + assert img.shape[1] % factor == 0 + + img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3]) + img = img.mean(axis=3) + img = img.mean(axis=1) + + return img + +def fill_coords(img, fn, color): + """ + Fill pixels of an image with coordinates matching a filter function + """ + + for y in range(img.shape[0]): + for x in range(img.shape[1]): + yf = (y + 0.5) / img.shape[0] + xf = (x + 0.5) / img.shape[1] + if fn(xf, yf): + img[y, x] = color + + return img + +def rotate_fn(fin, cx, cy, theta): + def fout(x, y): + x = x - cx + y = y - cy + + x2 = cx + x * math.cos(-theta) - y * math.sin(-theta) + y2 = cy + y * math.cos(-theta) + x * math.sin(-theta) + + return fin(x2, y2) + + return fout + +def point_in_line(x0, y0, x1, y1, r): + p0 = np.array([x0, y0]) + p1 = np.array([x1, y1]) + dir = p1 - p0 + dist = np.linalg.norm(dir) + dir = dir / dist + + xmin = min(x0, x1) - r + xmax = max(x0, x1) + r + ymin = min(y0, y1) - r + ymax = max(y0, y1) + r + + def fn(x, y): + # Fast, early escape test + if x < xmin or x > xmax or y < ymin or y > ymax: + return False + + q = np.array([x, y]) + pq = q - p0 + + # Closest point on line + a = np.dot(pq, dir) + a = np.clip(a, 0, dist) + p = p0 + a * dir + + dist_to_line = np.linalg.norm(q - p) + return dist_to_line <= r + + return fn + +def point_in_circle(cx, cy, r): + def fn(x, y): + return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r + return fn + +def point_in_rect(xmin, xmax, ymin, ymax): + def fn(x, y): + return x >= xmin and x <= xmax and y >= ymin and y <= ymax + return fn + +def point_in_triangle(a, b, c): + a = np.array(a) + b = np.array(b) + c = np.array(c) + + def fn(x, y): + v0 = c - a + v1 = b - a + v2 = np.array((x, y)) - a + + # Compute dot products + dot00 = np.dot(v0, v0) + dot01 = np.dot(v0, v1) + dot02 = np.dot(v0, v2) + dot11 = np.dot(v1, v1) + dot12 = np.dot(v1, v2) + + # Compute barycentric coordinates + inv_denom = 1 / (dot00 * dot11 - dot01 * dot01) + u = (dot11 * dot02 - dot01 * dot12) * inv_denom + v = (dot00 * dot12 - dot01 * dot02) * inv_denom + + # Check if point is in triangle + return (u >= 0) and (v >= 0) and (u + v) < 1 + + return fn + +def highlight_img(img, color=(255, 255, 
255), alpha=0.30): + """ + Add highlighting to an image + """ + + blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img) + blend_img = blend_img.clip(0, 255).astype(np.uint8) + img[:, :, :] = blend_img diff --git a/d4rl/d4rl/gym_minigrid/roomgrid.py b/d4rl/d4rl/gym_minigrid/roomgrid.py new file mode 100644 index 0000000..81e7d7a --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/roomgrid.py @@ -0,0 +1,397 @@ +from d4rl.gym_minigrid.minigrid import * + +def reject_next_to(env, pos): + """ + Function to filter out object positions that are right next to + the agent's starting point + """ + + sx, sy = env.agent_pos + x, y = pos + d = abs(sx - x) + abs(sy - y) + return d < 2 + +class Room: + def __init__( + self, + top, + size + ): + # Top-left corner and size (tuples) + self.top = top + self.size = size + + # List of door objects and door positions + # Order of the doors is right, down, left, up + self.doors = [None] * 4 + self.door_pos = [None] * 4 + + # List of rooms adjacent to this one + # Order of the neighbors is right, down, left, up + self.neighbors = [None] * 4 + + # Indicates if this room is behind a locked door + self.locked = False + + # List of objects contained + self.objs = [] + + def rand_pos(self, env): + topX, topY = self.top + sizeX, sizeY = self.size + return env._randPos( + topX + 1, topX + sizeX - 1, + topY + 1, topY + sizeY - 1 + ) + + def pos_inside(self, x, y): + """ + Check if a position is within the bounds of this room + """ + + topX, topY = self.top + sizeX, sizeY = self.size + + if x < topX or y < topY: + return False + + if x >= topX + sizeX or y >= topY + sizeY: + return False + + return True + +class RoomGrid(MiniGridEnv): + """ + Environment with multiple rooms and random objects. + This is meant to serve as a base class for other environments. 
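+ A minimal subclass sketch (illustrative only; the room/door helpers used
+ below are defined further down in this class):
+
+     class TwoRoomFetch(RoomGrid):
+         def __init__(self):
+             super().__init__(room_size=6, num_rows=1, num_cols=2, max_steps=200)
+
+         def _gen_grid(self, width, height):
+             super()._gen_grid(width, height)
+             # Place a random object in the right room, connect the two
+             # rooms with a door, and start the agent in the left room
+             self.obj, _ = self.add_object(1, 0)
+             self.connect_all()
+             self.place_agent(0, 0)
+             self.mission = 'fetch the %s %s' % (self.obj.color, self.obj.type)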
+ """ + + def __init__( + self, + room_size=7, + num_rows=3, + num_cols=3, + max_steps=100, + seed=0 + ): + assert room_size > 0 + assert room_size >= 3 + assert num_rows > 0 + assert num_cols > 0 + self.room_size = room_size + self.num_rows = num_rows + self.num_cols = num_cols + + height = (room_size - 1) * num_rows + 1 + width = (room_size - 1) * num_cols + 1 + + # By default, this environment has no mission + self.mission = '' + + super().__init__( + width=width, + height=height, + max_steps=max_steps, + see_through_walls=False, + seed=seed + ) + + def room_from_pos(self, x, y): + """Get the room a given position maps to""" + + assert x >= 0 + assert y >= 0 + + i = x // (self.room_size-1) + j = y // (self.room_size-1) + + assert i < self.num_cols + assert j < self.num_rows + + return self.room_grid[j][i] + + def get_room(self, i, j): + assert i < self.num_cols + assert j < self.num_rows + return self.room_grid[j][i] + + def _gen_grid(self, width, height): + # Create the grid + self.grid = Grid(width, height) + + self.room_grid = [] + + # For each row of rooms + for j in range(0, self.num_rows): + row = [] + + # For each column of rooms + for i in range(0, self.num_cols): + room = Room( + (i * (self.room_size-1), j * (self.room_size-1)), + (self.room_size, self.room_size) + ) + row.append(room) + + # Generate the walls for this room + self.grid.wall_rect(*room.top, *room.size) + + self.room_grid.append(row) + + # For each row of rooms + for j in range(0, self.num_rows): + # For each column of rooms + for i in range(0, self.num_cols): + room = self.room_grid[j][i] + + x_l, y_l = (room.top[0] + 1, room.top[1] + 1) + x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1) + + # Door positions, order is right, down, left, up + if i < self.num_cols - 1: + room.neighbors[0] = self.room_grid[j][i+1] + room.door_pos[0] = (x_m, self._rand_int(y_l, y_m)) + if j < self.num_rows - 1: + room.neighbors[1] = self.room_grid[j+1][i] + room.door_pos[1] = (self._rand_int(x_l, x_m), y_m) + if i > 0: + room.neighbors[2] = self.room_grid[j][i-1] + room.door_pos[2] = room.neighbors[2].door_pos[0] + if j > 0: + room.neighbors[3] = self.room_grid[j-1][i] + room.door_pos[3] = room.neighbors[3].door_pos[1] + + # The agent starts in the middle, facing right + self.agent_pos = ( + (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2), + (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2) + ) + self.agent_dir = 0 + + def place_in_room(self, i, j, obj): + """ + Add an existing object to room (i, j) + """ + + room = self.get_room(i, j) + + pos = self.place_obj( + obj, + room.top, + room.size, + reject_fn=reject_next_to, + max_tries=1000 + ) + + room.objs.append(obj) + + return obj, pos + + def add_object(self, i, j, kind=None, color=None): + """ + Add a new object to room (i, j) + """ + + if kind == None: + kind = self._rand_elem(['key', 'ball', 'box']) + + if color == None: + color = self._rand_color() + + # TODO: we probably want to add an Object.make helper function + assert kind in ['key', 'ball', 'box'] + if kind == 'key': + obj = Key(color) + elif kind == 'ball': + obj = Ball(color) + elif kind == 'box': + obj = Box(color) + + return self.place_in_room(i, j, obj) + + def add_door(self, i, j, door_idx=None, color=None, locked=None): + """ + Add a door to a room, connecting it to a neighbor + """ + + room = self.get_room(i, j) + + if door_idx == None: + # Need to make sure that there is a neighbor along this wall + # and that there is not already a door + while True: + 
door_idx = self._rand_int(0, 4) + if room.neighbors[door_idx] and room.doors[door_idx] is None: + break + + if color == None: + color = self._rand_color() + + if locked is None: + locked = self._rand_bool() + + assert room.doors[door_idx] is None, "door already exists" + + room.locked = locked + door = Door(color, is_locked=locked) + + pos = room.door_pos[door_idx] + self.grid.set(*pos, door) + door.cur_pos = pos + + neighbor = room.neighbors[door_idx] + room.doors[door_idx] = door + neighbor.doors[(door_idx+2) % 4] = door + + return door, pos + + def remove_wall(self, i, j, wall_idx): + """ + Remove a wall between two rooms + """ + + room = self.get_room(i, j) + + assert wall_idx >= 0 and wall_idx < 4 + assert room.doors[wall_idx] is None, "door exists on this wall" + assert room.neighbors[wall_idx], "invalid wall" + + neighbor = room.neighbors[wall_idx] + + tx, ty = room.top + w, h = room.size + + # Ordering of walls is right, down, left, up + if wall_idx == 0: + for i in range(1, h - 1): + self.grid.set(tx + w - 1, ty + i, None) + elif wall_idx == 1: + for i in range(1, w - 1): + self.grid.set(tx + i, ty + h - 1, None) + elif wall_idx == 2: + for i in range(1, h - 1): + self.grid.set(tx, ty + i, None) + elif wall_idx == 3: + for i in range(1, w - 1): + self.grid.set(tx + i, ty, None) + else: + assert False, "invalid wall index" + + # Mark the rooms as connected + room.doors[wall_idx] = True + neighbor.doors[(wall_idx+2) % 4] = True + + def place_agent(self, i=None, j=None, rand_dir=True): + """ + Place the agent in a room + """ + + if i == None: + i = self._rand_int(0, self.num_cols) + if j == None: + j = self._rand_int(0, self.num_rows) + + room = self.room_grid[j][i] + + # Find a position that is not right in front of an object + while True: + super().place_agent(room.top, room.size, rand_dir, max_tries=1000) + front_cell = self.grid.get(*self.front_pos) + if front_cell is None or front_cell.type is 'wall': + break + + return self.agent_pos + + def connect_all(self, door_colors=COLOR_NAMES, max_itrs=5000): + """ + Make sure that all rooms are reachable by the agent from its + starting position + """ + + start_room = self.room_from_pos(*self.agent_pos) + + added_doors = [] + + def find_reach(): + reach = set() + stack = [start_room] + while len(stack) > 0: + room = stack.pop() + if room in reach: + continue + reach.add(room) + for i in range(0, 4): + if room.doors[i]: + stack.append(room.neighbors[i]) + return reach + + num_itrs = 0 + + while True: + # This is to handle rare situations where random sampling produces + # a level that cannot be connected, producing in an infinite loop + if num_itrs > max_itrs: + raise RecursionError('connect_all failed') + num_itrs += 1 + + # If all rooms are reachable, stop + reach = find_reach() + if len(reach) == self.num_rows * self.num_cols: + break + + # Pick a random room and door position + i = self._rand_int(0, self.num_cols) + j = self._rand_int(0, self.num_rows) + k = self._rand_int(0, 4) + room = self.get_room(i, j) + + # If there is already a door there, skip + if not room.door_pos[k] or room.doors[k]: + continue + + if room.locked or room.neighbors[k].locked: + continue + + color = self._rand_elem(door_colors) + door, _ = self.add_door(i, j, k, color, False) + added_doors.append(door) + + return added_doors + + def add_distractors(self, i=None, j=None, num_distractors=10, all_unique=True): + """ + Add random objects that can potentially distract/confuse the agent. 
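+ Example (illustrative only):
+
+     # ten unique random key/ball/box objects spread over random rooms
+     self.add_distractors()
+     # two (possibly duplicate) objects placed in room (0, 0)
+     self.add_distractors(i=0, j=0, num_distractors=2, all_unique=False)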
+ """ + + # Collect a list of existing objects + objs = [] + for row in self.room_grid: + for room in row: + for obj in room.objs: + objs.append((obj.type, obj.color)) + + # List of distractors added + dists = [] + + while len(dists) < num_distractors: + color = self._rand_elem(COLOR_NAMES) + type = self._rand_elem(['key', 'ball', 'box']) + obj = (type, color) + + if all_unique and obj in objs: + continue + + # Add the object to a random room if no room specified + room_i = i + room_j = j + if room_i == None: + room_i = self._rand_int(0, self.num_cols) + if room_j == None: + room_j = self._rand_int(0, self.num_rows) + + dist, pos = self.add_object(room_i, room_j, *obj) + + objs.append(obj) + dists.append(dist) + + return dists diff --git a/d4rl/d4rl/gym_minigrid/window.py b/d4rl/d4rl/gym_minigrid/window.py new file mode 100644 index 0000000..d1abb3a --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/window.py @@ -0,0 +1,90 @@ +import sys +import numpy as np + +# Only ask users to install matplotlib if they actually need it +try: + import matplotlib.pyplot as plt +except: + print('To display the environment in a window, please install matplotlib, eg:') + print('pip3 install --user matplotlib') + sys.exit(-1) + +class Window: + """ + Window to draw a gridworld instance using Matplotlib + """ + + def __init__(self, title): + self.fig = None + + self.imshow_obj = None + + # Create the figure and axes + self.fig, self.ax = plt.subplots() + + # Show the env name in the window title + self.fig.canvas.set_window_title(title) + + # Turn off x/y axis numbering/ticks + self.ax.set_xticks([], []) + self.ax.set_yticks([], []) + + # Flag indicating the window was closed + self.closed = False + + def close_handler(evt): + self.closed = True + + self.fig.canvas.mpl_connect('close_event', close_handler) + + def show_img(self, img): + """ + Show an image or update the image being shown + """ + + # Show the first image of the environment + if self.imshow_obj is None: + self.imshow_obj = self.ax.imshow(img, interpolation='bilinear') + + self.imshow_obj.set_data(img) + self.fig.canvas.draw() + + # Let matplotlib process UI events + # This is needed for interactive mode to work properly + plt.pause(0.001) + + def set_caption(self, text): + """ + Set/update the caption text below the image + """ + + plt.xlabel(text) + + def reg_key_handler(self, key_handler): + """ + Register a keyboard event handler + """ + + # Keyboard handler + self.fig.canvas.mpl_connect('key_press_event', key_handler) + + def show(self, block=True): + """ + Show the window, and start an event loop + """ + + # If not blocking, trigger interactive mode + if not block: + plt.ion() + + # Show the plot + # In non-interative mode, this enters the matplotlib event loop + # In interactive mode, this call does not block + plt.show() + + def close(self): + """ + Close the window + """ + + plt.close() diff --git a/d4rl/d4rl/gym_minigrid/wrappers.py b/d4rl/d4rl/gym_minigrid/wrappers.py new file mode 100644 index 0000000..fe229c5 --- /dev/null +++ b/d4rl/d4rl/gym_minigrid/wrappers.py @@ -0,0 +1,330 @@ +import math +import operator +from functools import reduce + +import numpy as np +import gym +from gym import error, spaces, utils +from d4rl.gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX + +class ReseedWrapper(gym.core.Wrapper): + """ + Wrapper to always regenerate an environment with the same set of seeds. + This can be used to force an environment to always keep the same + configuration when reset. 
+ """ + + def __init__(self, env, seeds=[0], seed_idx=0): + self.seeds = list(seeds) + self.seed_idx = seed_idx + super().__init__(env) + + def reset(self, **kwargs): + seed = self.seeds[self.seed_idx] + self.seed_idx = (self.seed_idx + 1) % len(self.seeds) + self.env.seed(seed) + return self.env.reset(**kwargs) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + return obs, reward, done, info + +class ActionBonus(gym.core.Wrapper): + """ + Wrapper which adds an exploration bonus. + This is a reward to encourage exploration of less + visited (state,action) pairs. + """ + + def __init__(self, env): + super().__init__(env) + self.counts = {} + + def step(self, action): + obs, reward, done, info = self.env.step(action) + + env = self.unwrapped + tup = (tuple(env.agent_pos), env.agent_dir, action) + + # Get the count for this (s,a) pair + pre_count = 0 + if tup in self.counts: + pre_count = self.counts[tup] + + # Update the count for this (s,a) pair + new_count = pre_count + 1 + self.counts[tup] = new_count + + bonus = 1 / math.sqrt(new_count) + reward += bonus + + return obs, reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + +class StateBonus(gym.core.Wrapper): + """ + Adds an exploration bonus based on which positions + are visited on the grid. + """ + + def __init__(self, env): + super().__init__(env) + self.counts = {} + + def step(self, action): + obs, reward, done, info = self.env.step(action) + + # Tuple based on which we index the counts + # We use the position after an update + env = self.unwrapped + tup = (tuple(env.agent_pos)) + + # Get the count for this key + pre_count = 0 + if tup in self.counts: + pre_count = self.counts[tup] + + # Update the count for this key + new_count = pre_count + 1 + self.counts[tup] = new_count + + bonus = 1 / math.sqrt(new_count) + reward += bonus + + return obs, reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + +class ImgObsWrapper(gym.core.ObservationWrapper): + """ + Use the image as the only observation output, no language/mission. + """ + + def __init__(self, env): + super().__init__(env) + self.observation_space = env.observation_space.spaces['image'] + + def observation(self, obs): + return obs['image'] + +class OneHotPartialObsWrapper(gym.core.ObservationWrapper): + """ + Wrapper to get a one-hot encoding of a partially observable + agent view as observation. + """ + + def __init__(self, env, tile_size=8): + super().__init__(env) + + self.tile_size = tile_size + + obs_shape = env.observation_space['image'].shape + + # Number of bits per cell + num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX) + + self.observation_space.spaces["image"] = spaces.Box( + low=0, + high=255, + shape=(obs_shape[0], obs_shape[1], num_bits), + dtype='uint8' + ) + + def observation(self, obs): + img = obs['image'] + out = np.zeros(self.observation_space.shape, dtype='uint8') + + for i in range(img.shape[0]): + for j in range(img.shape[1]): + type = img[i, j, 0] + color = img[i, j, 1] + state = img[i, j, 2] + + out[i, j, type] = 1 + out[i, j, len(OBJECT_TO_IDX) + color] = 1 + out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1 + + return { + 'mission': obs['mission'], + 'image': out + } + +class RGBImgObsWrapper(gym.core.ObservationWrapper): + """ + Wrapper to use fully observable RGB image as the only observation output, + no language/mission. This can be used to have the agent to solve the + gridworld in pixel space. 
+ """ + + def __init__(self, env, tile_size=8): + super().__init__(env) + + self.tile_size = tile_size + + self.observation_space.spaces['image'] = spaces.Box( + low=0, + high=255, + shape=(self.env.width*tile_size, self.env.height*tile_size, 3), + dtype='uint8' + ) + + def observation(self, obs): + env = self.unwrapped + + rgb_img = env.render( + mode='rgb_array', + highlight=False, + tile_size=self.tile_size + ) + + return { + 'mission': obs['mission'], + 'image': rgb_img + } + + +class RGBImgPartialObsWrapper(gym.core.ObservationWrapper): + """ + Wrapper to use partially observable RGB image as the only observation output + This can be used to have the agent to solve the gridworld in pixel space. + """ + + def __init__(self, env, tile_size=8): + super().__init__(env) + + self.tile_size = tile_size + + obs_shape = env.observation_space['image'].shape + self.observation_space.spaces['image'] = spaces.Box( + low=0, + high=255, + shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3), + dtype='uint8' + ) + + def observation(self, obs): + env = self.unwrapped + + rgb_img_partial = env.get_obs_render( + obs['image'], + tile_size=self.tile_size + ) + + return { + 'mission': obs['mission'], + 'image': rgb_img_partial + } + +class FullyObsWrapper(gym.core.ObservationWrapper): + """ + Fully observable gridworld using a compact grid encoding + """ + + def __init__(self, env): + super().__init__(env) + + self.observation_space.spaces["image"] = spaces.Box( + low=0, + high=255, + shape=(self.env.width, self.env.height, 3), # number of cells + dtype='uint8' + ) + + def observation(self, obs): + env = self.unwrapped + full_grid = env.grid.encode() + full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([ + OBJECT_TO_IDX['agent'], + COLOR_TO_IDX['red'], + env.agent_dir + ]) + + return { + 'mission': obs['mission'], + 'image': full_grid + } + +class FlatObsWrapper(gym.core.ObservationWrapper): + """ + Encode mission strings using a one-hot scheme, + and combine these with observed images into one flat array + """ + + def __init__(self, env, maxStrLen=96): + super().__init__(env) + + self.maxStrLen = maxStrLen + self.numCharCodes = 27 + + imgSpace = env.observation_space.spaces['image'] + imgSize = reduce(operator.mul, imgSpace.shape, 1) + + self.observation_space = spaces.Box( + low=0, + high=255, + shape=(1, imgSize + self.numCharCodes * self.maxStrLen), + dtype='uint8' + ) + + self.cachedStr = None + self.cachedArray = None + + def observation(self, obs): + image = obs['image'] + mission = obs['mission'] + + # Cache the last-encoded mission string + if mission != self.cachedStr: + assert len(mission) <= self.maxStrLen, 'mission string too long ({} chars)'.format(len(mission)) + mission = mission.lower() + + strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32') + + for idx, ch in enumerate(mission): + if ch >= 'a' and ch <= 'z': + chNo = ord(ch) - ord('a') + elif ch == ' ': + chNo = ord('z') - ord('a') + 1 + assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo) + strArray[idx, chNo] = 1 + + self.cachedStr = mission + self.cachedArray = strArray + + obs = np.concatenate((image.flatten(), self.cachedArray.flatten())) + + return obs + +class ViewSizeWrapper(gym.core.Wrapper): + """ + Wrapper to customize the agent field of view size. + This cannot be used with fully observable wrappers. 
+ """ + + def __init__(self, env, agent_view_size=7): + super().__init__(env) + + # Override default view size + env.unwrapped.agent_view_size = agent_view_size + + # Compute observation space with specified view size + observation_space = gym.spaces.Box( + low=0, + high=255, + shape=(agent_view_size, agent_view_size, 3), + dtype='uint8' + ) + + # Override the environment's observation space + self.observation_space = spaces.Dict({ + 'image': observation_space + }) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + def step(self, action): + return self.env.step(action) diff --git a/d4rl/d4rl/gym_mujoco/__init__.py b/d4rl/d4rl/gym_mujoco/__init__.py new file mode 100644 index 0000000..baa19b6 --- /dev/null +++ b/d4rl/d4rl/gym_mujoco/__init__.py @@ -0,0 +1,286 @@ +from gym.envs.registration import register +from d4rl.gym_mujoco import gym_envs +from d4rl import infos + +# V1 envs +for agent in ['hopper', 'halfcheetah', 'ant', 'walker2d']: + for dataset in ['random', 'medium', 'expert', 'medium-expert', 'medium-replay', 'full-replay']: + for version in ['v1', 'v2']: + env_name = '%s-%s-%s' % (agent, dataset, version) + register( + id=env_name, + entry_point='d4rl.gym_mujoco.gym_envs:get_%s_env' % agent.replace('halfcheetah', 'cheetah').replace('walker2d', 'walker'), + max_episode_steps=1000, + kwargs={ + 'deprecated': version != 'v2', + 'ref_min_score': infos.REF_MIN_SCORE[env_name], + 'ref_max_score': infos.REF_MAX_SCORE[env_name], + 'dataset_url': infos.DATASET_URLS[env_name] + } + ) + + +HOPPER_RANDOM_SCORE = -20.272305 +HALFCHEETAH_RANDOM_SCORE = -280.178953 +WALKER_RANDOM_SCORE = 1.629008 +ANT_RANDOM_SCORE = -325.6 + +HOPPER_EXPERT_SCORE = 3234.3 +HALFCHEETAH_EXPERT_SCORE = 12135.0 +WALKER_EXPERT_SCORE = 4592.3 +ANT_EXPERT_SCORE = 3879.7 + +# Single Policy datasets +register( + id='hopper-medium-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HOPPER_RANDOM_SCORE, + 'ref_max_score': HOPPER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5' + } +) + +register( + id='halfcheetah-medium-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HALFCHEETAH_RANDOM_SCORE, + 'ref_max_score': HALFCHEETAH_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5' + } +) + +register( + id='walker2d-medium-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': WALKER_RANDOM_SCORE, + 'ref_max_score': WALKER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5' + } +) + +register( + id='hopper-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HOPPER_RANDOM_SCORE, + 'ref_max_score': HOPPER_EXPERT_SCORE, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5' + } +) + +register( + id='halfcheetah-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HALFCHEETAH_RANDOM_SCORE, + 'ref_max_score': HALFCHEETAH_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5' + } 
+) + +register( + id='walker2d-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': WALKER_RANDOM_SCORE, + 'ref_max_score': WALKER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5' + } +) + +register( + id='hopper-random-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HOPPER_RANDOM_SCORE, + 'ref_max_score': HOPPER_EXPERT_SCORE, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5' + } +) + +register( + id='halfcheetah-random-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HALFCHEETAH_RANDOM_SCORE, + 'ref_max_score': HALFCHEETAH_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5' + } +) + +register( + id='walker2d-random-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': WALKER_RANDOM_SCORE, + 'ref_max_score': WALKER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5' + } +) + +# Mixed datasets +register( + id='hopper-medium-replay-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HOPPER_RANDOM_SCORE, + 'ref_max_score': HOPPER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5' + }, +) + +register( + id='walker2d-medium-replay-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': WALKER_RANDOM_SCORE, + 'ref_max_score': WALKER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5' + } +) + +register( + id='halfcheetah-medium-replay-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HALFCHEETAH_RANDOM_SCORE, + 'ref_max_score': HALFCHEETAH_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5' + } +) + +# Mixtures of random/medium and experts +register( + id='walker2d-medium-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_walker_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': WALKER_RANDOM_SCORE, + 'ref_max_score': WALKER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5' + } +) + +register( + id='halfcheetah-medium-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_cheetah_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HALFCHEETAH_RANDOM_SCORE, + 'ref_max_score': HALFCHEETAH_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5' + } +) + +register( + id='hopper-medium-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_hopper_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': HOPPER_RANDOM_SCORE, + 'ref_max_score': HOPPER_EXPERT_SCORE, + 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5' + } +) + +register( + id='ant-medium-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5' + } +) + +register( + id='ant-medium-replay-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5' + } +) + +register( + id='ant-medium-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5' + } +) + +register( + id='ant-random-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5' + } +) + +register( + id='ant-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5' + } +) + +register( + id='ant-random-expert-v0', + entry_point='d4rl.gym_mujoco.gym_envs:get_ant_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'ref_min_score': ANT_RANDOM_SCORE, + 'ref_max_score': ANT_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5' + } +) diff --git a/d4rl/d4rl/gym_mujoco/gym_envs.py b/d4rl/d4rl/gym_mujoco/gym_envs.py new file mode 100644 index 0000000..6e9a34c --- /dev/null +++ b/d4rl/d4rl/gym_mujoco/gym_envs.py @@ -0,0 +1,40 @@ +from .. 
import offline_env +from gym.envs.mujoco import HalfCheetahEnv, AntEnv, HopperEnv, Walker2dEnv +from ..utils.wrappers import NormalizedBoxEnv + +class OfflineAntEnv(AntEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + AntEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineHopperEnv(HopperEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + HopperEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineHalfCheetahEnv(HalfCheetahEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + HalfCheetahEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + +class OfflineWalker2dEnv(Walker2dEnv, offline_env.OfflineEnv): + def __init__(self, **kwargs): + Walker2dEnv.__init__(self,) + offline_env.OfflineEnv.__init__(self, **kwargs) + + +def get_ant_env(**kwargs): + return NormalizedBoxEnv(OfflineAntEnv(**kwargs)) + +def get_cheetah_env(**kwargs): + return NormalizedBoxEnv(OfflineHalfCheetahEnv(**kwargs)) + +def get_hopper_env(**kwargs): + return NormalizedBoxEnv(OfflineHopperEnv(**kwargs)) + +def get_walker_env(**kwargs): + return NormalizedBoxEnv(OfflineWalker2dEnv(**kwargs)) + +if __name__ == '__main__': + """Example usage of these envs""" + pass diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore b/d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore new file mode 100644 index 0000000..5509140 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/.gitignore @@ -0,0 +1 @@ +*.DS_Store diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml new file mode 100644 index 0000000..d5a2a13 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand.xml @@ -0,0 +1,58 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml new file mode 100644 index 0000000..ab87c68 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/Adroit_hand_withOverlay.xml @@ -0,0 +1,57 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE b/d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/README.md b/d4rl/d4rl/hand_manipulation_suite/Adroit/README.md new file mode 100644 index 0000000..ad5451c --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/README.md @@ -0,0 +1,29 @@
+# Adroit Manipulation Platform
+
+The Adroit manipulation platform is a reconfigurable, tendon-driven, pneumatically-actuated platform designed and developed by [Vikash Kumar](https://vikashplus.github.io/) during his Ph.D. ([Thesis: Manipulators and Manipulation in high dimensional spaces](https://digital.lib.washington.edu/researchworks/handle/1773/38104)) to study dynamic dexterous manipulation. Adroit comprises the [Shadow Hand](https://www.shadowrobot.com/products/dexterous-hand/) skeleton (developed by the [Shadow Robot company](https://www.shadowrobot.com/)) and a custom arm, and is powered by a custom actuation system. This actuation system allows Adroit to move the ShadowHand skeleton faster than a human hand (70 msec limit-to-limit movement, 30 msec overall reflex latency), generate sufficient forces (40 N at each finger tendon, 125 N at each wrist tendon), and achieve high compliance on the mechanism level (6 grams of external force at the fingertip displaces the finger when the system is powered). This combination of speed, force, and compliance is a prerequisite for dexterous manipulation, yet it has never before been achieved with a tendon-driven system, let alone a system with 24 degrees of freedom and 40 tendons.
+
+## Mujoco Model
+Adroit is a 28-degree-of-freedom system consisting of a 24-degree-of-freedom **ShadowHand** and a 4-degree-of-freedom arm. This repository contains MuJoCo models of the system, developed with great care and attention to detail.
+
+
+## In Projects
+Adroit has been used in a wide variety of projects. A small list is appended below; details of these projects can be found [here](https://vikashplus.github.io/).
+[![projects](https://github.com/vikashplus/Adroit/blob/master/gallery/projects.JPG)](https://vikashplus.github.io/)
+## In News and Media
+Adroit has received considerable attention in the world media.
Details can be found [here](https://vikashplus.github.io/news.html) + +[![News](https://github.com/vikashplus/Adroit/blob/master/gallery/news.JPG)](https://vikashplus.github.io/news.html) + + +## Citation +If the contents of this repo helped you, please consider citing + +``` +@phdthesis{Kumar2016thesis, + title = {Manipulators and Manipulation in high dimensional spaces}, + school = {University of Washington, Seattle}, + author = {Kumar, Vikash}, + year = {2016}, + url = {https://digital.lib.washington.edu/researchworks/handle/1773/38104} +} +``` diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/news.JPG b/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/news.JPG new file mode 100644 index 0000000..e53b317 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/news.JPG differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/projects.JPG b/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/projects.JPG new file mode 100644 index 0000000..1360a90 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/gallery/projects.JPG differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml new file mode 100644 index 0000000..c2a85eb --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/assets.xml @@ -0,0 +1,345 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml new file mode 100644 index 0000000..ba2353a --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain.xml @@ -0,0 +1,226 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml new file mode 100644 index 0000000..7db7a10 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/chain1.xml @@ -0,0 +1,227 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml 
b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml new file mode 100644 index 0000000..0a17b5c --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/joint_position_actuation.xml @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F1.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F1.stl new file mode 100644 index 0000000..515d3c9 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F1.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F2.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F2.stl new file mode 100644 index 0000000..7bc5e20 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F2.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F3.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F3.stl new file mode 100644 index 0000000..223f06f Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/F3.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH1_z.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH1_z.stl new file mode 100644 index 0000000..400ee2d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH1_z.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH2_z.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH2_z.stl new file mode 100644 index 0000000..5ace838 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH2_z.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH3_z.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH3_z.stl new file mode 100644 index 0000000..23485ab Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/TH3_z.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_base.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_base.stl new file mode 100644 index 0000000..d9a26d5 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_base.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk.stl new file mode 100644 index 0000000..fb2ee17 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk_asmbly.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk_asmbly.stl new file mode 100644 index 0000000..bafae4e Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/arm_trunk_asmbly.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/distal_ellipsoid.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/distal_ellipsoid.stl new file mode 100644 index 0000000..8906519 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/distal_ellipsoid.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_flex.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_flex.stl new file mode 100644 index 
0000000..b004e74 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_flex.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_motor.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_motor.stl new file mode 100644 index 0000000..4e849bb Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_motor.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_muscle.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_muscle.stl new file mode 100644 index 0000000..ee6470c Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/elbow_rotate_muscle.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly(muscle_cone).stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly(muscle_cone).stl new file mode 100644 index 0000000..6ce05b9 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly(muscle_cone).stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly.stl new file mode 100644 index 0000000..0945410 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_Cy_PlateAsmbly.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_PlateAsmbly.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_PlateAsmbly.stl new file mode 100644 index 0000000..85491bb Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_PlateAsmbly.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric.stl new file mode 100644 index 0000000..80f6f3d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric_cvx.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric_cvx.stl new file mode 100644 index 0000000..3c30f57 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_electric_cvx.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_muscle.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_muscle.stl new file mode 100644 index 0000000..c47c510 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_muscle.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple.stl new file mode 100644 index 0000000..888d2d3 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple_cvx.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple_cvx.stl new file mode 100644 index 0000000..1f133ca Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_simple_cvx.stl differ 
diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_weight.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_weight.stl new file mode 100644 index 0000000..515050a Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/forearm_weight.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/knuckle.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/knuckle.stl new file mode 100644 index 0000000..4faedd7 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/knuckle.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/lfmetacarpal.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/lfmetacarpal.stl new file mode 100644 index 0000000..535cf4d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/lfmetacarpal.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/palm.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/palm.stl new file mode 100644 index 0000000..65e47eb Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/palm.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm.stl new file mode 100644 index 0000000..5045e82 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_asmbl_shoulder.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_asmbl_shoulder.stl new file mode 100644 index 0000000..e6ffe6d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_asmbl_shoulder.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_ass.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_ass.stl new file mode 100644 index 0000000..cbf0b74 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/upper_arm_ass.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/wrist.stl b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/wrist.stl new file mode 100644 index 0000000..420d5f9 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/meshes/wrist.stl differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml new file mode 100644 index 0000000..1572f74 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/tendon_torque_actuation.xml @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/darkwood.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/darkwood.png new file mode 100644 index 0000000..d5dcc5c Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/darkwood.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/dice.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/dice.png new file 
mode 100644 index 0000000..798a8e0 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/dice.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/foil.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/foil.png new file mode 100644 index 0000000..654cfe1 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/foil.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/marble.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/marble.png new file mode 100644 index 0000000..c50e8b9 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/marble.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/silverRaw.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/silverRaw.png new file mode 100644 index 0000000..13690e5 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/silverRaw.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/skin.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/skin.png new file mode 100644 index 0000000..54e528d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/skin.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/square.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/square.png new file mode 100644 index 0000000..dbfd695 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/square.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/wood.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/wood.png new file mode 100644 index 0000000..c323cb9 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/wood.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/woodb.png b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/woodb.png new file mode 100644 index 0000000..47f94a8 Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/Adroit/resources/textures/woodb.png differ diff --git a/d4rl/d4rl/hand_manipulation_suite/__init__.py b/d4rl/d4rl/hand_manipulation_suite/__init__.py new file mode 100644 index 0000000..b5639f1 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/__init__.py @@ -0,0 +1,274 @@ +from gym.envs.registration import register +from mjrl.envs.mujoco_env import MujocoEnv +from d4rl.hand_manipulation_suite.door_v0 import DoorEnvV0 +from d4rl.hand_manipulation_suite.hammer_v0 import HammerEnvV0 +from d4rl.hand_manipulation_suite.pen_v0 import PenEnvV0 +from d4rl.hand_manipulation_suite.relocate_v0 import RelocateEnvV0 +from d4rl import infos + + +# V1 envs +MAX_STEPS = {'hammer': 200, 'relocate': 200, 'door': 200, 'pen': 100} +LONG_HORIZONS = {'hammer': 600, 'pen': 200, 'relocate': 500, 'door': 300} +ENV_MAPPING = {'hammer': 'HammerEnvV0', 'relocate': 'RelocateEnvV0', 'door': 'DoorEnvV0', 'pen': 'PenEnvV0'} +for agent in ['hammer', 'pen', 'relocate', 'door']: + for dataset in ['human', 'expert', 'cloned']: + env_name = '%s-%s-v1' % (agent, dataset) + register( + id=env_name, + entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent], + max_episode_steps=MAX_STEPS[agent], + kwargs={ + 'ref_min_score': infos.REF_MIN_SCORE[env_name], + 'ref_max_score': infos.REF_MAX_SCORE[env_name], + 
'dataset_url': infos.DATASET_URLS[env_name] + } + ) + + if dataset == 'human': + longhorizon_env_name = '%s-human-longhorizon-v1' % agent + register( + id=longhorizon_env_name, + entry_point='d4rl.hand_manipulation_suite:' + ENV_MAPPING[agent], + max_episode_steps=LONG_HORIZONS[agent], + kwargs={ + 'ref_min_score': infos.REF_MIN_SCORE[env_name], + 'ref_max_score': infos.REF_MAX_SCORE[env_name], + 'dataset_url': infos.DATASET_URLS[env_name] + } + ) + +DOOR_RANDOM_SCORE = -56.512833 +DOOR_EXPERT_SCORE = 2880.5693087298737 + +HAMMER_RANDOM_SCORE = -274.856578 +HAMMER_EXPERT_SCORE = 12794.134825156867 + +PEN_RANDOM_SCORE = 96.262799 +PEN_EXPERT_SCORE = 3076.8331017826877 + +RELOCATE_RANDOM_SCORE = -6.425911 +RELOCATE_EXPERT_SCORE = 4233.877797728884 + +# Swing the door open +register( + id='door-v0', + entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', + max_episode_steps=200, +) + +register( + id='door-human-v0', + entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': DOOR_RANDOM_SCORE, + 'ref_max_score': DOOR_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5' + } +) + +register( + id='door-human-longhorizon-v0', + entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', + max_episode_steps=300, + kwargs={ + 'deprecated': True, + 'ref_min_score': DOOR_RANDOM_SCORE, + 'ref_max_score': DOOR_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5' + } +) + +register( + id='door-cloned-v0', + entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': DOOR_RANDOM_SCORE, + 'ref_max_score': DOOR_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5' + } +) + +register( + id='door-expert-v0', + entry_point='d4rl.hand_manipulation_suite:DoorEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': DOOR_RANDOM_SCORE, + 'ref_max_score': DOOR_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5' + } +) + +# Hammer a nail into the board +register( + id='hammer-v0', + entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', + max_episode_steps=200, +) + +register( + id='hammer-human-v0', + entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': HAMMER_RANDOM_SCORE, + 'ref_max_score': HAMMER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5' + } +) + +register( + id='hammer-human-longhorizon-v0', + entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', + max_episode_steps=600, + kwargs={ + 'deprecated': True, + 'ref_min_score': HAMMER_RANDOM_SCORE, + 'ref_max_score': HAMMER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5' + } +) + +register( + id='hammer-cloned-v0', + entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': HAMMER_RANDOM_SCORE, + 'ref_max_score': HAMMER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5' + } +) + +register( + id='hammer-expert-v0', + 
entry_point='d4rl.hand_manipulation_suite:HammerEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': HAMMER_RANDOM_SCORE, + 'ref_max_score': HAMMER_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5' + } +) + + +# Reposition a pen in hand +register( + id='pen-v0', + entry_point='d4rl.hand_manipulation_suite:PenEnvV0', + max_episode_steps=100, +) + +register( + id='pen-human-v0', + entry_point='d4rl.hand_manipulation_suite:PenEnvV0', + max_episode_steps=100, + kwargs={ + 'deprecated': True, + 'ref_min_score': PEN_RANDOM_SCORE, + 'ref_max_score': PEN_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5' + } +) + +register( + id='pen-human-longhorizon-v0', + entry_point='d4rl.hand_manipulation_suite:PenEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': PEN_RANDOM_SCORE, + 'ref_max_score': PEN_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5' + } +) + +register( + id='pen-cloned-v0', + entry_point='d4rl.hand_manipulation_suite:PenEnvV0', + max_episode_steps=100, + kwargs={ + 'deprecated': True, + 'ref_min_score': PEN_RANDOM_SCORE, + 'ref_max_score': PEN_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5' + } +) + +register( + id='pen-expert-v0', + entry_point='d4rl.hand_manipulation_suite:PenEnvV0', + max_episode_steps=100, + kwargs={ + 'deprecated': True, + 'ref_min_score': PEN_RANDOM_SCORE, + 'ref_max_score': PEN_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5' + } +) + + +# Relcoate an object to the target +register( + id='relocate-v0', + entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', + max_episode_steps=200, +) + +register( + id='relocate-human-v0', + entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': RELOCATE_RANDOM_SCORE, + 'ref_max_score': RELOCATE_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5' + } +) + +register( + id='relocate-human-longhorizon-v0', + entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', + max_episode_steps=500, + kwargs={ + 'deprecated': True, + 'ref_min_score': RELOCATE_RANDOM_SCORE, + 'ref_max_score': RELOCATE_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5' + } +) + +register( + id='relocate-cloned-v0', + entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': RELOCATE_RANDOM_SCORE, + 'ref_max_score': RELOCATE_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5' + } +) + +register( + id='relocate-expert-v0', + entry_point='d4rl.hand_manipulation_suite:RelocateEnvV0', + max_episode_steps=200, + kwargs={ + 'deprecated': True, + 'ref_min_score': RELOCATE_RANDOM_SCORE, + 'ref_max_score': RELOCATE_EXPERT_SCORE, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5' + } +) + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml new file mode 100644 
index 0000000..3d00adc --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_Adroit.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml new file mode 100644 index 0000000..d17001d --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_assets.xml @@ -0,0 +1,345 @@ + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml new file mode 100644 index 0000000..831ff31 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_door.xml @@ -0,0 +1,92 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml new file mode 100644 index 0000000..f15ae28 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_hammer.xml @@ -0,0 +1,111 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml new file mode 100644 index 0000000..2abab7a --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_pen.xml @@ -0,0 +1,90 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml new file mode 100644 index 0000000..8f2e10a --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/assets/DAPG_relocate.xml @@ -0,0 +1,88 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/hand_manipulation_suite/assets/tasks.jpg b/d4rl/d4rl/hand_manipulation_suite/assets/tasks.jpg new file mode 100644 index 0000000..4b1499d Binary files /dev/null and b/d4rl/d4rl/hand_manipulation_suite/assets/tasks.jpg differ diff --git a/d4rl/d4rl/hand_manipulation_suite/door_v0.py b/d4rl/d4rl/hand_manipulation_suite/door_v0.py new file mode 100644 index 0000000..fafe170 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/door_v0.py @@ -0,0 +1,130 @@ +import numpy as np +from gym import utils +from gym import spaces +from mjrl.envs import mujoco_env +from mujoco_py import MjViewer +from d4rl import offline_env +import os + +ADD_BONUS_REWARDS = True + +class DoorEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): + def __init__(self, **kwargs): + offline_env.OfflineEnv.__init__(self, **kwargs) + self.door_hinge_did = 0 + self.door_bid = 0 + self.grasp_sid = 0 + self.handle_sid = 0 + curr_dir = os.path.dirname(os.path.abspath(__file__)) + mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_door.xml', 5) + + # Override action_space to 
-1, 1 + self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) + + # change actuator sensitivity + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) + + utils.EzPickle.__init__(self) + ob = self.reset_model() + self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) + self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) + self.door_hinge_did = self.model.jnt_dofadr[self.model.joint_name2id('door_hinge')] + self.grasp_sid = self.model.site_name2id('S_grasp') + self.handle_sid = self.model.site_name2id('S_handle') + self.door_bid = self.model.body_name2id('frame') + + def step(self, a): + a = np.clip(a, -1.0, 1.0) + try: + a = self.act_mid + a*self.act_rng # mean center and scale + except: + a = a # only for the initialization phase + self.do_simulation(a, self.frame_skip) + ob = self.get_obs() + handle_pos = self.data.site_xpos[self.handle_sid].ravel() + palm_pos = self.data.site_xpos[self.grasp_sid].ravel() + door_pos = self.data.qpos[self.door_hinge_did] + + # get to handle + reward = -0.1*np.linalg.norm(palm_pos-handle_pos) + # open door + reward += -0.1*(door_pos - 1.57)*(door_pos - 1.57) + # velocity cost + reward += -1e-5*np.sum(self.data.qvel**2) + + if ADD_BONUS_REWARDS: + # Bonus + if door_pos > 0.2: + reward += 2 + if door_pos > 1.0: + reward += 8 + if door_pos > 1.35: + reward += 10 + + goal_achieved = True if door_pos >= 1.35 else False + + return ob, reward, False, dict(goal_achieved=goal_achieved) + + def get_obs(self): + # qpos for hand + # xpos for obj + # xpos for target + qp = self.data.qpos.ravel() + handle_pos = self.data.site_xpos[self.handle_sid].ravel() + palm_pos = self.data.site_xpos[self.grasp_sid].ravel() + door_pos = np.array([self.data.qpos[self.door_hinge_did]]) + if door_pos > 1.0: + door_open = 1.0 + else: + door_open = -1.0 + latch_pos = qp[-1] + return np.concatenate([qp[1:-2], [latch_pos], door_pos, palm_pos, handle_pos, palm_pos-handle_pos, [door_open]]) + + def reset_model(self): + qp = self.init_qpos.copy() + qv = self.init_qvel.copy() + self.set_state(qp, qv) + + self.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2) + self.model.body_pos[self.door_bid,1] = self.np_random.uniform(low=0.25, high=0.35) + self.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35) + self.sim.forward() + return self.get_obs() + + def get_env_state(self): + """ + Get state of hand as well as objects and targets in the scene + """ + qp = self.data.qpos.ravel().copy() + qv = self.data.qvel.ravel().copy() + door_body_pos = self.model.body_pos[self.door_bid].ravel().copy() + return dict(qpos=qp, qvel=qv, door_body_pos=door_body_pos) + + def set_env_state(self, state_dict): + """ + Set the state which includes hand as well as objects and targets in the scene + """ + qp = state_dict['qpos'] + qv = state_dict['qvel'] + self.set_state(qp, qv) + self.model.body_pos[self.door_bid] = state_dict['door_body_pos'] + 
self.sim.forward() + + def mj_viewer_setup(self): + self.viewer = MjViewer(self.sim) + self.viewer.cam.azimuth = 90 + self.sim.forward() + self.viewer.cam.distance = 1.5 + + def evaluate_success(self, paths): + num_success = 0 + num_paths = len(paths) + # success if door open for 25 steps + for path in paths: + if np.sum(path['env_infos']['goal_achieved']) > 25: + num_success += 1 + success_percentage = num_success*100.0/num_paths + return success_percentage diff --git a/d4rl/d4rl/hand_manipulation_suite/hammer_v0.py b/d4rl/d4rl/hand_manipulation_suite/hammer_v0.py new file mode 100644 index 0000000..727980a --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/hammer_v0.py @@ -0,0 +1,135 @@ +import numpy as np +from gym import utils +from gym import spaces +from mjrl.envs import mujoco_env +from mujoco_py import MjViewer +from d4rl.utils.quatmath import quat2euler +from d4rl import offline_env +import os + +ADD_BONUS_REWARDS = True + +class HammerEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): + def __init__(self, **kwargs): + offline_env.OfflineEnv.__init__(self, **kwargs) + self.target_obj_sid = -1 + self.S_grasp_sid = -1 + self.obj_bid = -1 + self.tool_sid = -1 + self.goal_sid = -1 + curr_dir = os.path.dirname(os.path.abspath(__file__)) + mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_hammer.xml', 5) + + # Override action_space to -1, 1 + self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) + + utils.EzPickle.__init__(self) + + # change actuator sensitivity + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) + + self.target_obj_sid = self.sim.model.site_name2id('S_target') + self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') + self.obj_bid = self.sim.model.body_name2id('Object') + self.tool_sid = self.sim.model.site_name2id('tool') + self.goal_sid = self.sim.model.site_name2id('nail_goal') + self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) + self.act_rng = 0.5 * (self.model.actuator_ctrlrange[:, 1] - self.model.actuator_ctrlrange[:, 0]) + + def step(self, a): + a = np.clip(a, -1.0, 1.0) + try: + a = self.act_mid + a * self.act_rng # mean center and scale + except: + a = a # only for the initialization phase + self.do_simulation(a, self.frame_skip) + ob = self.get_obs() + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() + tool_pos = self.data.site_xpos[self.tool_sid].ravel() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel() + goal_pos = self.data.site_xpos[self.goal_sid].ravel() + + # get to hammer + reward = - 0.1 * np.linalg.norm(palm_pos - obj_pos) + # take hammer head to nail + reward -= np.linalg.norm((tool_pos - target_pos)) + # make nail go inside + reward -= 10 * np.linalg.norm(target_pos - goal_pos) + # velocity penalty + reward -= 1e-2 * np.linalg.norm(self.data.qvel.ravel()) + + if ADD_BONUS_REWARDS: + # bonus for lifting up the hammer + if obj_pos[2] > 0.04 and tool_pos[2] > 
0.04: + reward += 2 + + # bonus for hammering the nail + if (np.linalg.norm(target_pos - goal_pos) < 0.020): + reward += 25 + if (np.linalg.norm(target_pos - goal_pos) < 0.010): + reward += 75 + + goal_achieved = True if np.linalg.norm(target_pos - goal_pos) < 0.010 else False + + return ob, reward, False, dict(goal_achieved=goal_achieved) + + def get_obs(self): + # qpos for hand + # xpos for obj + # xpos for target + qp = self.data.qpos.ravel() + qv = np.clip(self.data.qvel.ravel(), -1.0, 1.0) + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + obj_rot = quat2euler(self.data.body_xquat[self.obj_bid].ravel()).ravel() + palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel() + nail_impact = np.clip(self.sim.data.sensordata[self.sim.model.sensor_name2id('S_nail')], -1.0, 1.0) + return np.concatenate([qp[:-6], qv[-6:], palm_pos, obj_pos, obj_rot, target_pos, np.array([nail_impact])]) + + def reset_model(self): + self.sim.reset() + target_bid = self.model.body_name2id('nail_board') + self.model.body_pos[target_bid,2] = self.np_random.uniform(low=0.1, high=0.25) + self.sim.forward() + return self.get_obs() + + def get_env_state(self): + """ + Get state of hand as well as objects and targets in the scene + """ + qpos = self.data.qpos.ravel().copy() + qvel = self.data.qvel.ravel().copy() + board_pos = self.model.body_pos[self.model.body_name2id('nail_board')].copy() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel().copy() + return dict(qpos=qpos, qvel=qvel, board_pos=board_pos, target_pos=target_pos) + + def set_env_state(self, state_dict): + """ + Set the state which includes hand as well as objects and targets in the scene + """ + qp = state_dict['qpos'] + qv = state_dict['qvel'] + board_pos = state_dict['board_pos'] + self.set_state(qp, qv) + self.model.body_pos[self.model.body_name2id('nail_board')] = board_pos + self.sim.forward() + + def mj_viewer_setup(self): + self.viewer = MjViewer(self.sim) + self.viewer.cam.azimuth = 45 + self.viewer.cam.distance = 2.0 + self.sim.forward() + + def evaluate_success(self, paths): + num_success = 0 + num_paths = len(paths) + # success if nail insude board for 25 steps + for path in paths: + if np.sum(path['env_infos']['goal_achieved']) > 25: + num_success += 1 + success_percentage = num_success*100.0/num_paths + return success_percentage diff --git a/d4rl/d4rl/hand_manipulation_suite/pen_v0.py b/d4rl/d4rl/hand_manipulation_suite/pen_v0.py new file mode 100644 index 0000000..3abf864 --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/pen_v0.py @@ -0,0 +1,148 @@ +import numpy as np +from gym import utils +from gym import spaces +from mjrl.envs import mujoco_env +from d4rl.utils.quatmath import quat2euler, euler2quat +from d4rl import offline_env +from mujoco_py import MjViewer +import os + +ADD_BONUS_REWARDS = True + +class PenEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): + def __init__(self, **kwargs): + offline_env.OfflineEnv.__init__(self, **kwargs) + self.target_obj_bid = 0 + self.S_grasp_sid = 0 + self.eps_ball_sid = 0 + self.obj_bid = 0 + self.obj_t_sid = 0 + self.obj_b_sid = 0 + self.tar_t_sid = 0 + self.tar_b_sid = 0 + self.pen_length = 1.0 + self.tar_length = 1.0 + + curr_dir = os.path.dirname(os.path.abspath(__file__)) + mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_pen.xml', 5) + + # Override action_space to -1, 1 + self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) + + # change 
actuator sensitivity + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) + + utils.EzPickle.__init__(self) + self.target_obj_bid = self.sim.model.body_name2id("target") + self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') + self.obj_bid = self.sim.model.body_name2id('Object') + self.eps_ball_sid = self.sim.model.site_name2id('eps_ball') + self.obj_t_sid = self.sim.model.site_name2id('object_top') + self.obj_b_sid = self.sim.model.site_name2id('object_bottom') + self.tar_t_sid = self.sim.model.site_name2id('target_top') + self.tar_b_sid = self.sim.model.site_name2id('target_bottom') + + self.pen_length = np.linalg.norm(self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid]) + self.tar_length = np.linalg.norm(self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid]) + + self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) + self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) + + def step(self, a): + a = np.clip(a, -1.0, 1.0) + try: + starting_up = False + a = self.act_mid + a*self.act_rng # mean center and scale + except: + starting_up = True + a = a # only for the initialization phase + self.do_simulation(a, self.frame_skip) + + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + desired_loc = self.data.site_xpos[self.eps_ball_sid].ravel() + obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length + desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length + + # pos cost + dist = np.linalg.norm(obj_pos-desired_loc) + reward = -dist + # orien cost + orien_similarity = np.dot(obj_orien, desired_orien) + reward += orien_similarity + + if ADD_BONUS_REWARDS: + # bonus for being close to desired orientation + if dist < 0.075 and orien_similarity > 0.9: + reward += 10 + if dist < 0.075 and orien_similarity > 0.95: + reward += 50 + + # penalty for dropping the pen + done = False + if obj_pos[2] < 0.075: + reward -= 5 + done = True if not starting_up else False + + goal_achieved = True if (dist < 0.075 and orien_similarity > 0.95) else False + + return self.get_obs(), reward, done, dict(goal_achieved=goal_achieved) + + def get_obs(self): + qp = self.data.qpos.ravel() + obj_vel = self.data.qvel[-6:].ravel() + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + desired_pos = self.data.site_xpos[self.eps_ball_sid].ravel() + obj_orien = (self.data.site_xpos[self.obj_t_sid] - self.data.site_xpos[self.obj_b_sid])/self.pen_length + desired_orien = (self.data.site_xpos[self.tar_t_sid] - self.data.site_xpos[self.tar_b_sid])/self.tar_length + return np.concatenate([qp[:-6], obj_pos, obj_vel, obj_orien, desired_orien, + obj_pos-desired_pos, obj_orien-desired_orien]) + + def reset_model(self): + qp = self.init_qpos.copy() + qv = self.init_qvel.copy() + self.set_state(qp, qv) + desired_orien = np.zeros(3) + desired_orien[0] = self.np_random.uniform(low=-1, high=1) + desired_orien[1] = 
self.np_random.uniform(low=-1, high=1) + self.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) + self.sim.forward() + return self.get_obs() + + def get_env_state(self): + """ + Get state of hand as well as objects and targets in the scene + """ + qp = self.data.qpos.ravel().copy() + qv = self.data.qvel.ravel().copy() + desired_orien = self.model.body_quat[self.target_obj_bid].ravel().copy() + return dict(qpos=qp, qvel=qv, desired_orien=desired_orien) + + def set_env_state(self, state_dict): + """ + Set the state which includes hand as well as objects and targets in the scene + """ + qp = state_dict['qpos'] + qv = state_dict['qvel'] + desired_orien = state_dict['desired_orien'] + self.set_state(qp, qv) + self.model.body_quat[self.target_obj_bid] = desired_orien + self.sim.forward() + + def mj_viewer_setup(self): + self.viewer = MjViewer(self.sim) + self.viewer.cam.azimuth = -45 + self.sim.forward() + self.viewer.cam.distance = 1.0 + + def evaluate_success(self, paths): + num_success = 0 + num_paths = len(paths) + # success if pen within 15 degrees of target for 20 steps + for path in paths: + if np.sum(path['env_infos']['goal_achieved']) > 20: + num_success += 1 + success_percentage = num_success*100.0/num_paths + return success_percentage diff --git a/d4rl/d4rl/hand_manipulation_suite/relocate_v0.py b/d4rl/d4rl/hand_manipulation_suite/relocate_v0.py new file mode 100644 index 0000000..305d33c --- /dev/null +++ b/d4rl/d4rl/hand_manipulation_suite/relocate_v0.py @@ -0,0 +1,126 @@ +import numpy as np +from gym import utils +from gym import spaces +from mjrl.envs import mujoco_env +from mujoco_py import MjViewer +from d4rl import offline_env +import os + +ADD_BONUS_REWARDS = True + +class RelocateEnvV0(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): + def __init__(self, **kwargs): + offline_env.OfflineEnv.__init__(self, **kwargs) + self.target_obj_sid = 0 + self.S_grasp_sid = 0 + self.obj_bid = 0 + curr_dir = os.path.dirname(os.path.abspath(__file__)) + mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/DAPG_relocate.xml', 5) + + # Override action_space to -1, 1 + self.action_space = spaces.Box(low=-1.0, high=1.0, dtype=np.float32, shape=self.action_space.shape) + + # change actuator sensitivity + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([10, 0, 0]) + self.sim.model.actuator_gainprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([1, 0, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_WRJ1'):self.sim.model.actuator_name2id('A_WRJ0')+1,:3] = np.array([0, -10, 0]) + self.sim.model.actuator_biasprm[self.sim.model.actuator_name2id('A_FFJ3'):self.sim.model.actuator_name2id('A_THJ0')+1,:3] = np.array([0, -1, 0]) + + self.target_obj_sid = self.sim.model.site_name2id("target") + self.S_grasp_sid = self.sim.model.site_name2id('S_grasp') + self.obj_bid = self.sim.model.body_name2id('Object') + utils.EzPickle.__init__(self) + self.act_mid = np.mean(self.model.actuator_ctrlrange, axis=1) + self.act_rng = 0.5*(self.model.actuator_ctrlrange[:,1]-self.model.actuator_ctrlrange[:,0]) + + def step(self, a): + a = np.clip(a, -1.0, 1.0) + try: + a = self.act_mid + a*self.act_rng # mean center and scale + except: + a = a # only for the initialization phase + self.do_simulation(a, self.frame_skip) + ob = self.get_obs() + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + palm_pos = 
self.data.site_xpos[self.S_grasp_sid].ravel() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel() + + reward = -0.1*np.linalg.norm(palm_pos-obj_pos) # take hand to object + if obj_pos[2] > 0.04: # if object off the table + reward += 1.0 # bonus for lifting the object + reward += -0.5*np.linalg.norm(palm_pos-target_pos) # make hand go to target + reward += -0.5*np.linalg.norm(obj_pos-target_pos) # make object go to target + + if ADD_BONUS_REWARDS: + if np.linalg.norm(obj_pos-target_pos) < 0.1: + reward += 10.0 # bonus for object close to target + if np.linalg.norm(obj_pos-target_pos) < 0.05: + reward += 20.0 # bonus for object "very" close to target + + goal_achieved = True if np.linalg.norm(obj_pos-target_pos) < 0.1 else False + + return ob, reward, False, dict(goal_achieved=goal_achieved) + + def get_obs(self): + # qpos for hand + # xpos for obj + # xpos for target + qp = self.data.qpos.ravel() + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel() + return np.concatenate([qp[:-6], palm_pos-obj_pos, palm_pos-target_pos, obj_pos-target_pos]) + + def reset_model(self): + qp = self.init_qpos.copy() + qv = self.init_qvel.copy() + self.set_state(qp, qv) + self.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15) + self.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3) + self.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2) + self.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2) + self.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35) + self.sim.forward() + return self.get_obs() + + def get_env_state(self): + """ + Get state of hand as well as objects and targets in the scene + """ + qp = self.data.qpos.ravel().copy() + qv = self.data.qvel.ravel().copy() + hand_qpos = qp[:30] + obj_pos = self.data.body_xpos[self.obj_bid].ravel() + palm_pos = self.data.site_xpos[self.S_grasp_sid].ravel() + target_pos = self.data.site_xpos[self.target_obj_sid].ravel() + return dict(hand_qpos=hand_qpos, obj_pos=obj_pos, target_pos=target_pos, palm_pos=palm_pos, + qpos=qp, qvel=qv) + + def set_env_state(self, state_dict): + """ + Set the state which includes hand as well as objects and targets in the scene + """ + qp = state_dict['qpos'] + qv = state_dict['qvel'] + obj_pos = state_dict['obj_pos'] + target_pos = state_dict['target_pos'] + self.set_state(qp, qv) + self.model.body_pos[self.obj_bid] = obj_pos + self.model.site_pos[self.target_obj_sid] = target_pos + self.sim.forward() + + def mj_viewer_setup(self): + self.viewer = MjViewer(self.sim) + self.viewer.cam.azimuth = 90 + self.sim.forward() + self.viewer.cam.distance = 1.5 + + def evaluate_success(self, paths): + num_success = 0 + num_paths = len(paths) + # success if object close to target for 25 steps + for path in paths: + if np.sum(path['env_infos']['goal_achieved']) > 25: + num_success += 1 + success_percentage = num_success*100.0/num_paths + return success_percentage diff --git a/d4rl/d4rl/infos.py b/d4rl/d4rl/infos.py new file mode 100644 index 0000000..fb9256c --- /dev/null +++ b/d4rl/d4rl/infos.py @@ -0,0 +1,311 @@ +""" +This file holds all URLs and reference scores. +""" + +#TODO(Justin): This is duplicated. Make all __init__ file URLs and scores point to this file. 
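The `REF_MIN_SCORE` / `REF_MAX_SCORE` tables defined below are the anchors for D4RL-style score normalization: a raw episode return is rescaled linearly so that the random-policy reference maps to 0 and the expert reference maps to 100. A minimal sketch of that convention, assuming only that this module is importable as `d4rl.infos` (the helper name `normalized_score` is illustrative, not part of the upstream API):

```python
# Illustrative sketch of the D4RL normalization convention; not the library's own helper.
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE


def normalized_score(env_name: str, raw_return: float) -> float:
    """Rescale a raw return so the random reference maps to 0 and the expert reference to 100."""
    lo, hi = REF_MIN_SCORE[env_name], REF_MAX_SCORE[env_name]
    return 100.0 * (raw_return - lo) / (hi - lo)


# Example: a return of 1500 on door-human-v0 lands near 53, given the door references
# of roughly -56.5 (random) and 2880.6 (expert) listed below.
print(round(normalized_score('door-human-v0', 1500.0), 1))  # ~53.0
```

Values above 100 (better than the expert reference) and below 0 (worse than the random reference) are possible, since the rescaling is purely linear.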
+ +DATASET_URLS = { + 'maze2d-open-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5', + 'maze2d-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5', + 'maze2d-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5', + 'maze2d-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5', + 'maze2d-eval-umaze-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5', + 'maze2d-eval-medium-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5', + 'maze2d-eval-large-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5', + 'maze2d-open-dense-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5', + 'maze2d-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5', + 'maze2d-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5', + 'maze2d-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5', + 'maze2d-eval-umaze-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5', + 'maze2d-eval-medium-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5', + 'maze2d-eval-large-dense-v1' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5', + 'minigrid-fourrooms-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5', + 'minigrid-fourrooms-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5', + 'pen-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5', + 'pen-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5', + 'pen-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5', + 'hammer-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5', + 'hammer-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5', + 'hammer-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5', + 'relocate-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5', + 'relocate-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5', + 'relocate-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5', + 'door-human-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5', + 'door-cloned-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5', + 'door-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5', + 'halfcheetah-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5', + 'halfcheetah-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5', + 'halfcheetah-expert-v0' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5', + 'halfcheetah-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5', + 'halfcheetah-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5', + 'walker2d-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5', + 'walker2d-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5', + 'walker2d-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5', + 'walker2d-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5', + 'walker2d-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5', + 'hopper-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5', + 'hopper-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5', + 'hopper-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5', + 'hopper-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5', + 'hopper-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5', + 'ant-random-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5', + 'ant-medium-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5', + 'ant-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5', + 'ant-medium-replay-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5', + 'ant-medium-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5', + 'ant-random-expert-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5', + 'antmaze-umaze-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5', + 'antmaze-umaze-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'antmaze-medium-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'antmaze-medium-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'antmaze-large-play-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'antmaze-large-diverse-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'antmaze-umaze-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5', + 'antmaze-umaze-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'antmaze-medium-play-v2' : 
'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', + 'antmaze-medium-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'antmaze-large-play-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', + 'antmaze-large-diverse-v2' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'flow-ring-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5', + 'flow-ring-controller-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5', + 'flow-merge-random-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5', + 'flow-merge-controller-v0':'http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5', + 'kitchen-complete-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5', + 'kitchen-partial-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5', + 'kitchen-mixed-v0' : 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5', + 'carla-lane-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5', + 'carla-town-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5', + 'carla-town-full-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5', + 'bullet-halfcheetah-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5', + 'bullet-halfcheetah-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5', + 'bullet-halfcheetah-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5', + 'bullet-halfcheetah-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5', + 'bullet-halfcheetah-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5', + 'bullet-hopper-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5', + 'bullet-hopper-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5', + 'bullet-hopper-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5', + 'bullet-hopper-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5', + 'bullet-hopper-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5', + 'bullet-ant-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5', + 'bullet-ant-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5', + 'bullet-ant-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5', + 'bullet-ant-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5', + 
'bullet-ant-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5', + 'bullet-walker2d-random-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5', + 'bullet-walker2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5', + 'bullet-walker2d-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5', + 'bullet-walker2d-medium-expert-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5', + 'bullet-walker2d-medium-replay-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5', + 'bullet-maze2d-open-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5', + 'bullet-maze2d-umaze-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5', + 'bullet-maze2d-medium-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5', + 'bullet-maze2d-large-v0': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5', +} + + +REF_MIN_SCORE = { + 'maze2d-open-v0' : 0.01 , + 'maze2d-umaze-v1' : 23.85 , + 'maze2d-medium-v1' : 13.13 , + 'maze2d-large-v1' : 6.7 , + 'maze2d-open-dense-v0' : 11.17817 , + 'maze2d-umaze-dense-v1' : 68.537689 , + 'maze2d-medium-dense-v1' : 44.264742 , + 'maze2d-large-dense-v1' : 30.569041 , + 'minigrid-fourrooms-v0' : 0.01442 , + 'minigrid-fourrooms-random-v0' : 0.01442 , + 'pen-human-v0' : 96.262799 , + 'pen-cloned-v0' : 96.262799 , + 'pen-expert-v0' : 96.262799 , + 'hammer-human-v0' : -274.856578 , + 'hammer-cloned-v0' : -274.856578 , + 'hammer-expert-v0' : -274.856578 , + 'relocate-human-v0' : -6.425911 , + 'relocate-cloned-v0' : -6.425911 , + 'relocate-expert-v0' : -6.425911 , + 'door-human-v0' : -56.512833 , + 'door-cloned-v0' : -56.512833 , + 'door-expert-v0' : -56.512833 , + 'halfcheetah-random-v0' : -280.178953 , + 'halfcheetah-medium-v0' : -280.178953 , + 'halfcheetah-expert-v0' : -280.178953 , + 'halfcheetah-medium-replay-v0' : -280.178953 , + 'halfcheetah-medium-expert-v0' : -280.178953 , + 'walker2d-random-v0' : 1.629008 , + 'walker2d-medium-v0' : 1.629008 , + 'walker2d-expert-v0' : 1.629008 , + 'walker2d-medium-replay-v0' : 1.629008 , + 'walker2d-medium-expert-v0' : 1.629008 , + 'hopper-random-v0' : -20.272305 , + 'hopper-medium-v0' : -20.272305 , + 'hopper-expert-v0' : -20.272305 , + 'hopper-medium-replay-v0' : -20.272305 , + 'hopper-medium-expert-v0' : -20.272305 , + 'ant-random-v0' : -325.6, + 'ant-medium-v0' : -325.6, + 'ant-expert-v0' : -325.6, + 'ant-medium-replay-v0' : -325.6, + 'ant-medium-expert-v0' : -325.6, + 'antmaze-umaze-v0' : 0.0 , + 'antmaze-umaze-diverse-v0' : 0.0 , + 'antmaze-medium-play-v0' : 0.0 , + 'antmaze-medium-diverse-v0' : 0.0 , + 'antmaze-large-play-v0' : 0.0 , + 'antmaze-large-diverse-v0' : 0.0 , + 'antmaze-umaze-v2' : 0.0 , + 'antmaze-umaze-diverse-v2' : 0.0 , + 'antmaze-medium-play-v2' : 0.0 , + 'antmaze-medium-diverse-v2' : 0.0 , + 'antmaze-large-play-v2' : 0.0 , + 'antmaze-large-diverse-v2' : 0.0 , + 'kitchen-complete-v0' : 0.0 , + 'kitchen-partial-v0' : 0.0 , + 'kitchen-mixed-v0' : 0.0 , + 'flow-ring-random-v0' : -165.22 , + 'flow-ring-controller-v0' : -165.22 , + 'flow-merge-random-v0' : 118.67993 , + 'flow-merge-controller-v0' : 118.67993 , + 'carla-lane-v0': -0.8503839912088142, + 'carla-town-v0': -114.81579500772153, # random 
score + 'bullet-halfcheetah-random-v0': -1275.766996, + 'bullet-halfcheetah-medium-v0': -1275.766996, + 'bullet-halfcheetah-expert-v0': -1275.766996, + 'bullet-halfcheetah-medium-expert-v0': -1275.766996, + 'bullet-halfcheetah-medium-replay-v0': -1275.766996, + 'bullet-hopper-random-v0': 20.058972, + 'bullet-hopper-medium-v0': 20.058972, + 'bullet-hopper-expert-v0': 20.058972, + 'bullet-hopper-medium-expert-v0': 20.058972, + 'bullet-hopper-medium-replay-v0': 20.058972, + 'bullet-ant-random-v0': 373.705955, + 'bullet-ant-medium-v0': 373.705955, + 'bullet-ant-expert-v0': 373.705955, + 'bullet-ant-medium-expert-v0': 373.705955, + 'bullet-ant-medium-replay-v0': 373.705955, + 'bullet-walker2d-random-v0': 16.523877, + 'bullet-walker2d-medium-v0': 16.523877, + 'bullet-walker2d-expert-v0': 16.523877, + 'bullet-walker2d-medium-expert-v0': 16.523877, + 'bullet-walker2d-medium-replay-v0': 16.523877, + 'bullet-maze2d-open-v0': 8.750000, + 'bullet-maze2d-umaze-v0': 32.460000, + 'bullet-maze2d-medium-v0': 14.870000, + 'bullet-maze2d-large-v0': 1.820000, +} + +REF_MAX_SCORE = { + 'maze2d-open-v0' : 20.66 , + 'maze2d-umaze-v1' : 161.86 , + 'maze2d-medium-v1' : 277.39 , + 'maze2d-large-v1' : 273.99 , + 'maze2d-open-dense-v0' : 27.166538620695782 , + 'maze2d-umaze-dense-v1' : 193.66285642381482 , + 'maze2d-medium-dense-v1' : 297.4552547777125 , + 'maze2d-large-dense-v1' : 303.4857382709002 , + 'minigrid-fourrooms-v0' : 2.89685 , + 'minigrid-fourrooms-random-v0' : 2.89685 , + 'pen-human-v0' : 3076.8331017826877 , + 'pen-cloned-v0' : 3076.8331017826877 , + 'pen-expert-v0' : 3076.8331017826877 , + 'hammer-human-v0' : 12794.134825156867 , + 'hammer-cloned-v0' : 12794.134825156867 , + 'hammer-expert-v0' : 12794.134825156867 , + 'relocate-human-v0' : 4233.877797728884 , + 'relocate-cloned-v0' : 4233.877797728884 , + 'relocate-expert-v0' : 4233.877797728884 , + 'door-human-v0' : 2880.5693087298737 , + 'door-cloned-v0' : 2880.5693087298737 , + 'door-expert-v0' : 2880.5693087298737 , + 'halfcheetah-random-v0' : 12135.0 , + 'halfcheetah-medium-v0' : 12135.0 , + 'halfcheetah-expert-v0' : 12135.0 , + 'halfcheetah-medium-replay-v0' : 12135.0 , + 'halfcheetah-medium-expert-v0' : 12135.0 , + 'walker2d-random-v0' : 4592.3 , + 'walker2d-medium-v0' : 4592.3 , + 'walker2d-expert-v0' : 4592.3 , + 'walker2d-medium-replay-v0' : 4592.3 , + 'walker2d-medium-expert-v0' : 4592.3 , + 'hopper-random-v0' : 3234.3 , + 'hopper-medium-v0' : 3234.3 , + 'hopper-expert-v0' : 3234.3 , + 'hopper-medium-replay-v0' : 3234.3 , + 'hopper-medium-expert-v0' : 3234.3 , + 'ant-random-v0' : 3879.7, + 'ant-medium-v0' : 3879.7, + 'ant-expert-v0' : 3879.7, + 'ant-medium-replay-v0' : 3879.7, + 'ant-medium-expert-v0' : 3879.7, + 'antmaze-umaze-v0' : 1.0 , + 'antmaze-umaze-diverse-v0' : 1.0 , + 'antmaze-medium-play-v0' : 1.0 , + 'antmaze-medium-diverse-v0' : 1.0 , + 'antmaze-large-play-v0' : 1.0 , + 'antmaze-large-diverse-v0' : 1.0 , + 'antmaze-umaze-v2' : 1.0 , + 'antmaze-umaze-diverse-v2' : 1.0 , + 'antmaze-medium-play-v2' : 1.0 , + 'antmaze-medium-diverse-v2' : 1.0 , + 'antmaze-large-play-v2' : 1.0 , + 'antmaze-large-diverse-v2' : 1.0 , + 'kitchen-complete-v0' : 4.0 , + 'kitchen-partial-v0' : 4.0 , + 'kitchen-mixed-v0' : 4.0 , + 'flow-ring-random-v0' : 24.42 , + 'flow-ring-controller-v0' : 24.42 , + 'flow-merge-random-v0' : 330.03179 , + 'flow-merge-controller-v0' : 330.03179 , + 'carla-lane-v0': 1023.5784385429523, + 'carla-town-v0': 2440.1772022247314, # avg dataset score + 'bullet-halfcheetah-random-v0': 2381.6725, + 'bullet-halfcheetah-medium-v0': 
2381.6725, + 'bullet-halfcheetah-expert-v0': 2381.6725, + 'bullet-halfcheetah-medium-expert-v0': 2381.6725, + 'bullet-halfcheetah-medium-replay-v0': 2381.6725, + 'bullet-hopper-random-v0': 1441.8059623430963, + 'bullet-hopper-medium-v0': 1441.8059623430963, + 'bullet-hopper-expert-v0': 1441.8059623430963, + 'bullet-hopper-medium-expert-v0': 1441.8059623430963, + 'bullet-hopper-medium-replay-v0': 1441.8059623430963, + 'bullet-ant-random-v0': 2650.495, + 'bullet-ant-medium-v0': 2650.495, + 'bullet-ant-expert-v0': 2650.495, + 'bullet-ant-medium-expert-v0': 2650.495, + 'bullet-ant-medium-replay-v0': 2650.495, + 'bullet-walker2d-random-v0': 1623.6476303317536, + 'bullet-walker2d-medium-v0': 1623.6476303317536, + 'bullet-walker2d-expert-v0': 1623.6476303317536, + 'bullet-walker2d-medium-expert-v0': 1623.6476303317536, + 'bullet-walker2d-medium-replay-v0': 1623.6476303317536, + 'bullet-maze2d-open-v0': 64.15, + 'bullet-maze2d-umaze-v0': 153.99, + 'bullet-maze2d-medium-v0': 238.05, + 'bullet-maze2d-large-v0': 285.92, +} + + +#Gym-MuJoCo V1/V2 envs +for env in ['halfcheetah', 'hopper', 'walker2d', 'ant']: + for dset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']: + #v1 envs + dset_name = env+'_'+dset.replace('-', '_')+'-v1' + env_name = dset_name.replace('_', '-') + DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v1/%s.hdf5' % dset_name + REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0'] + REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0'] + + #v2 envs + dset_name = env+'_'+dset.replace('-', '_')+'-v2' + env_name = dset_name.replace('_', '-') + DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/%s.hdf5' % dset_name + REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-random-v0'] + REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-random-v0'] + +#Adroit v1 envs +for env in ['hammer', 'pen', 'relocate', 'door']: + for dset in ['human', 'expert', 'cloned']: + env_name = env+'-'+dset+'-v1' + DATASET_URLS[env_name] = 'http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg_v1/%s.hdf5' % env_name + REF_MIN_SCORE[env_name] = REF_MIN_SCORE[env+'-human-v0'] + REF_MAX_SCORE[env_name] = REF_MAX_SCORE[env+'-human-v0'] + diff --git a/d4rl/d4rl/kitchen/__init__.py b/d4rl/d4rl/kitchen/__init__.py new file mode 100644 index 0000000..a769720 --- /dev/null +++ b/d4rl/d4rl/kitchen/__init__.py @@ -0,0 +1,41 @@ +from .kitchen_envs import KitchenMicrowaveKettleLightSliderV0, KitchenMicrowaveKettleBottomBurnerLightV0 +from gym.envs.registration import register + +# Smaller dataset with only positive demonstrations. +register( + id='kitchen-complete-v0', + entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0', + max_episode_steps=280, + kwargs={ + 'ref_min_score': 0.0, + 'ref_max_score': 4.0, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5' + } +) + +# Whole dataset with undirected demonstrations. A subset of the demonstrations +# solve the task. +register( + id='kitchen-partial-v0', + entry_point='d4rl.kitchen:KitchenMicrowaveKettleLightSliderV0', + max_episode_steps=280, + kwargs={ + 'ref_min_score': 0.0, + 'ref_max_score': 4.0, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5' + } +) + +# Whole dataset with undirected demonstrations. 
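The REF_MIN_SCORE / REF_MAX_SCORE tables above (together with the v1/v2 entries the loops fill in) exist so raw episode returns can be mapped onto the 0-100 scale reported for D4RL tasks. A minimal sketch of that normalization, assuming the two dictionaries are importable from this module (the exact import path is an assumption, not something the diff pins down):

```python
# Minimal sketch of D4RL-style score normalization built on the tables above.
# REF_MIN_SCORE / REF_MAX_SCORE are assumed to be in scope (e.g. imported from
# the infos module this diff adds them to).

def get_normalized_score(env_name: str, episode_return: float) -> float:
    """Map a raw episode return onto the 0-100 scale reported for D4RL tasks."""
    ref_min = REF_MIN_SCORE[env_name]
    ref_max = REF_MAX_SCORE[env_name]
    return 100.0 * (episode_return - ref_min) / (ref_max - ref_min)

# Example: a hopper-medium-v2 return of 1600 normalizes to roughly
# 100 * (1600 - (-20.27)) / (3234.3 - (-20.27)) ~= 49.8
```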
No demonstration completely +# solves the task, but each demonstration partially solves different +# components of the task. +register( + id='kitchen-mixed-v0', + entry_point='d4rl.kitchen:KitchenMicrowaveKettleBottomBurnerLightV0', + max_episode_steps=280, + kwargs={ + 'ref_min_score': 0.0, + 'ref_max_score': 4.0, + 'dataset_url': 'http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5' + } +) diff --git a/d4rl/d4rl/kitchen/adept_envs/.pylintrc b/d4rl/d4rl/kitchen/adept_envs/.pylintrc new file mode 100644 index 0000000..9cda412 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/.pylintrc @@ -0,0 +1,433 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=relative-beyond-top-level + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. 
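Once registered, the kitchen-* ids above behave like any other D4RL task, and the dataset_url passed through kwargs is what the offline-dataset loader pulls from. A hedged usage sketch, assuming the standard gym + d4rl entry points (gym.make and env.get_dataset) are available:

```python
# Hedged sketch: loading one of the kitchen datasets registered above.
# Relies on the usual d4rl OfflineEnv API (gym.make + env.get_dataset);
# only the environment id comes from this diff.
import gym
import d4rl  # noqa: F401  (importing d4rl registers the kitchen-* ids)

env = gym.make('kitchen-mixed-v0')
dataset = env.get_dataset()   # fetches the HDF5 from dataset_url on first use
print(dataset['observations'].shape, dataset['actions'].shape)
```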
This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, while `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=80 + +# Maximum number of lines in a module +max-module-lines=99999 + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma, + dict-separator + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. 
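The `evaluation` entry in the [REPORTS] section above is a plain Python expression that pylint evaluates with its message counters; a toy illustration with made-up counts:

```python
# Toy illustration of the `evaluation` expression from the [REPORTS] section.
# The counter names mirror the ones pylint substitutes; the values are made up.
error, warning, refactor, convention, statement = 1, 4, 2, 3, 200
score = 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
print(round(score, 2))  # 9.3 / 10 for this hypothetical module
```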
+contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[BASIC] + +# Naming style matching correct argument names +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style +argument-rgx=^[a-z][a-z0-9_]*$ + +# Naming style matching correct attribute names +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Naming style matching correct class attribute names +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Naming style matching correct class names +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming-style +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Naming style matching correct constant names +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. 
Overrides const-naming- +# style +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + +# Naming style matching correct function names +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main, + _ + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Naming style matching correct inline iteration names +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Naming style matching correct method names +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Naming style matching correct module names +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style +module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$ + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group=function:method + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,google3.pyglib.function_utils.cached.property + +# Naming style matching correct variable names +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style +variable-rgx=^[a-z][a-z0-9_]*$ + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package.. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. 
+deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception". +overgeneral-exceptions=Exception diff --git a/d4rl/d4rl/kitchen/adept_envs/.style.yapf b/d4rl/d4rl/kitchen/adept_envs/.style.yapf new file mode 100644 index 0000000..29f83ff --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/.style.yapf @@ -0,0 +1,323 @@ +[style] +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent=False + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys=False + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas=False + +# Allow splitting before a default / named assignment in an argument list. +allow_split_before_default_or_named_assigns=True + +# Allow splits before the dictionary value. +allow_split_before_dict_value=True + +# Let spacing indicate operator precedence. For example: +# +# a = 1 * 2 + 3 / 4 +# b = 1 / 2 - 3 * 4 +# c = (1 + 2) * (3 - 4) +# d = (1 - 2) / (3 + 4) +# e = 1 * 2 - 3 +# f = 1 + 2 + 3 + 4 +# +# will be formatted as follows to indicate precedence: +# +# a = 1*2 + 3/4 +# b = 1/2 - 3*4 +# c = (1+2) * (3-4) +# d = (1-2) / (3+4) +# e = 1*2 - 3 +# f = 1 + 2 + 3 + 4 +# +arithmetic_precedence_indication=False + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition=2 + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring=False + +# Insert a blank line before a module docstring. +blank_line_before_module_docstring=False + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def=True + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets=False + +# The column limit. +column_limit=80 + +# The style for continuation alignment. 
Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs) for continuation +# alignment. +# - LESS: Slightly left if cannot vertically align continuation lines with +# indent characters. +# - VALIGN-RIGHT: Vertically align continuation lines with indent +# characters. Slightly right (one more indent character) if cannot +# vertically align continuation lines with indent characters. +# +# For options FIXED, and VALIGN-RIGHT are only available when USE_TABS is +# enabled. +continuation_align_style=SPACE + +# Indent width used for line continuations. +continuation_indent_width=4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets=False + +# Disable the heuristic which places each list element on a separate line +# if the list is comma-terminated. +disable_ending_comma_heuristic=False + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line=True + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment=#\..* + +# The i18n function call names. The presence of this function stops +# reformattting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call=N_, _ + +# Indent blank lines. +indent_blank_lines=False + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value=False + +# The number of columns to use for indentation. +indent_width=4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines=True + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with "*,/": +# +# 1 + 2*3 - 4/5 +# +no_spaces_around_selected_binary_operators= + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign=False + +# Use spaces around the power operator. +spaces_around_power_operator=False + +# The number of spaces required before a trailing comment. +# This can be a single value (representing the number of spaces +# before each trailing comment) or list of values (representing +# alignment column values; trailing comments within a block will +# be aligned to the first column value that is greater than the maximum +# line length within the block). 
For example: +# +# With spaces_before_comment=5: +# +# 1 + 1 # Adding values +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment +# +# With spaces_before_comment=15, 20: +# +# 1 + 1 # Adding values +# two + two # More adding +# +# longer_statement # This is a longer statement +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment +# short # This is a shorter statement +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 +# two + two # More adding +# +# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length +# short # This is a shorter statement +# +spaces_before_comment=2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket=False + +# Split before arguments +split_all_comma_separated_values=False + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated=False + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator=False + +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket=True + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator=False + +# Split before the '.' if we need to split a longer expression: +# +# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) +# +# would reformat to something like: +# +# foo = ('This is a really long string: {}, {}, {}, {}' +# .format(a, b, c, d)) +split_before_dot=False + +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren=False + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument=False + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator=False + +# Split named assignments onto individual lines. +split_before_named_assigns=True + +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension=True + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket=30 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator=10000 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr=0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator=300 + +# The penalty for splitting a list comprehension or generator +# expression. 
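This .style.yapf is read by yapf like any other style file; a hedged sketch of applying it programmatically (the tuple return value of FormatCode holds for recent yapf releases and is an assumption here):

```python
# Hedged sketch: formatting a snippet against the .style.yapf defined above.
# Assumes a reasonably recent yapf, where FormatCode returns (text, changed).
from yapf.yapflib.yapf_api import FormatCode

source = "x = {  'a':37,'b':42,\n\n'c':927}\n"
formatted, changed = FormatCode(source, style_config='.style.yapf')
print(formatted)
```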
+split_penalty_comprehension=2100 + +# The penalty for characters over the column limit. +split_penalty_excess_character=7000 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split=30 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names=0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator=300 + +# Use the Tab character for indentation. +use_tabs=False + diff --git a/d4rl/d4rl/kitchen/adept_envs/__init__.py b/d4rl/d4rl/kitchen/adept_envs/__init__.py new file mode 100644 index 0000000..20e809d --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import d4rl.kitchen.adept_envs.franka + +from d4rl.kitchen.adept_envs.utils.configurable import global_config diff --git a/d4rl/d4rl/kitchen/adept_envs/base_robot.py b/d4rl/d4rl/kitchen/adept_envs/base_robot.py new file mode 100644 index 0000000..5c6f30f --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/base_robot.py @@ -0,0 +1,151 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from collections import deque + +class BaseRobot(object): + """Base class for all robot classes.""" + + def __init__(self, + n_jnt, + n_obj, + pos_bounds=None, + vel_bounds=None, + calibration_path=None, + is_hardware=False, + device_name=None, + overlay=False, + calibration_mode=False, + observation_cache_maxsize=5): + """Create a new robot. + Args: + n_jnt: The number of dofs in the robot. + n_obj: The number of dofs in the object. + pos_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint + position for each joint. + vel_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint + velocity for each joint. + calibration_path: File path to the calibration configuration file to + use. + is_hardware: Whether to run on hardware or not. + device_name: The device path for the robot hardware. Only required + in legacy mode. + overlay: Whether to show a simulation overlay of the hardware. 
+ calibration_mode: Start with motors disengaged. + """ + + assert n_jnt > 0 + assert n_obj >= 0 + + self._n_jnt = n_jnt + self._n_obj = n_obj + self._n_dofs = n_jnt + n_obj + + self._pos_bounds = None + if pos_bounds is not None: + pos_bounds = np.array(pos_bounds, dtype=np.float32) + assert pos_bounds.shape == (self._n_dofs, 2) + for low, high in pos_bounds: + assert low < high + self._pos_bounds = pos_bounds + self._vel_bounds = None + if vel_bounds is not None: + vel_bounds = np.array(vel_bounds, dtype=np.float32) + assert vel_bounds.shape == (self._n_dofs, 2) + for low, high in vel_bounds: + assert low < high + self._vel_bounds = vel_bounds + + self._is_hardware = is_hardware + self._device_name = device_name + self._calibration_path = calibration_path + self._overlay = overlay + self._calibration_mode = calibration_mode + self._observation_cache_maxsize = observation_cache_maxsize + + # Gets updated + self._observation_cache = deque([], maxlen=self._observation_cache_maxsize) + + + @property + def n_jnt(self): + return self._n_jnt + + @property + def n_obj(self): + return self._n_obj + + @property + def n_dofs(self): + return self._n_dofs + + @property + def pos_bounds(self): + return self._pos_bounds + + @property + def vel_bounds(self): + return self._vel_bounds + + @property + def is_hardware(self): + return self._is_hardware + + @property + def device_name(self): + return self._device_name + + @property + def calibration_path(self): + return self._calibration_path + + @property + def overlay(self): + return self._overlay + + @property + def has_obj(self): + return self._n_obj > 0 + + @property + def calibration_mode(self): + return self._calibration_mode + + @property + def observation_cache_maxsize(self): + return self._observation_cache_maxsize + + @property + def observation_cache(self): + return self._observation_cache + + + def clip_positions(self, positions): + """Clips the given joint positions to the position bounds. + + Args: + positions: The joint positions. + + Returns: + The bounded joint positions. + """ + if self.pos_bounds is None: + return positions + assert len(positions) == self.n_jnt or len(positions) == self.n_dofs + pos_bounds = self.pos_bounds[:len(positions)] + return np.clip(positions, pos_bounds[:, 0], pos_bounds[:, 1]) + diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/__init__.py b/d4rl/d4rl/kitchen/adept_envs/franka/__init__.py new file mode 100644 index 0000000..528f344 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/franka/__init__.py @@ -0,0 +1,24 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
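BaseRobot above is mostly bookkeeping (dof counts, bounds, the observation cache); clip_positions is the one piece of behaviour worth a quick illustration. A toy usage with made-up bounds:

```python
# Toy usage of BaseRobot position clipping; the bound values are made up.
import numpy as np
from d4rl.kitchen.adept_envs.base_robot import BaseRobot

robot = BaseRobot(
    n_jnt=2,
    n_obj=0,
    pos_bounds=[[-1.0, 1.0], [-0.5, 0.5]],  # (n_dofs, 2) min/max per joint
)
print(robot.clip_positions(np.array([2.0, -0.7])))  # -> [ 1.  -0.5]
```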
+ +from gym.envs.registration import register + +# Relax the robot +register( + id='kitchen_relax-v1', + entry_point='adept_envs.franka.kitchen_multitask_v0:KitchenTaskRelaxV1', + max_episode_steps=280, +) \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml b/d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml new file mode 100644 index 0000000..dd8b6c5 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml @@ -0,0 +1,94 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py b/d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py new file mode 100644 index 0000000..eb7ba51 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/franka/kitchen_multitask_v0.py @@ -0,0 +1,198 @@ +""" Kitchen environment for long horizon manipulation """ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from d4rl.kitchen.adept_envs import robot_env +from d4rl.kitchen.adept_envs.utils.configurable import configurable +from gym import spaces +from dm_control.mujoco import engine + +@configurable(pickleable=True) +class KitchenV0(robot_env.RobotEnv): + + CALIBRATION_PATHS = { + 'default': + os.path.join(os.path.dirname(__file__), 'robot/franka_config.xml') + } + # Converted to velocity actuation + ROBOTS = {'robot': 'd4rl.kitchen.adept_envs.franka.robot.franka_robot:Robot_VelAct'} + MODEl = os.path.join( + os.path.dirname(__file__), + '../franka/assets/franka_kitchen_jntpos_act_ab.xml') + N_DOF_ROBOT = 9 + N_DOF_OBJECT = 21 + + def __init__(self, robot_params={}, frame_skip=40): + self.goal_concat = True + self.obs_dict = {} + self.robot_noise_ratio = 0.1 # 10% as per robot_config specs + self.goal = np.zeros((30,)) + + super().__init__( + self.MODEl, + robot=self.make_robot( + n_jnt=self.N_DOF_ROBOT, #root+robot_jnts + n_obj=self.N_DOF_OBJECT, + **robot_params), + frame_skip=frame_skip, + camera_settings=dict( + distance=4.5, + azimuth=-66, + elevation=-65, + ), + ) + self.init_qpos = self.sim.model.key_qpos[0].copy() + + # For the microwave kettle slide hinge + self.init_qpos = np.array([ 1.48388023e-01, -1.76848573e+00, 1.84390296e+00, -2.47685760e+00, + 2.60252026e-01, 7.12533105e-01, 1.59515394e+00, 4.79267505e-02, + 3.71350919e-02, -2.66279850e-04, -5.18043486e-05, 3.12877220e-05, + -4.51199853e-05, -3.90842156e-06, -4.22629655e-05, 6.28065475e-05, + 4.04984708e-05, 4.62730939e-04, -2.26906415e-04, -4.65501369e-04, + -6.44129196e-03, -1.77048263e-03, 1.08009684e-03, -2.69397440e-01, + 3.50383255e-01, 1.61944683e+00, 1.00618764e+00, 4.06395120e-03, + -6.62095997e-03, -2.68278933e-04]) + + self.init_qvel = self.sim.model.key_qvel[0].copy() + + self.act_mid = np.zeros(self.N_DOF_ROBOT) + self.act_amp = 
2.0 * np.ones(self.N_DOF_ROBOT) + + act_lower = -1*np.ones((self.N_DOF_ROBOT,)) + act_upper = 1*np.ones((self.N_DOF_ROBOT,)) + self.action_space = spaces.Box(act_lower, act_upper) + + obs_upper = 8. * np.ones(self.obs_dim) + obs_lower = -obs_upper + self.observation_space = spaces.Box(obs_lower, obs_upper) + + def _get_reward_n_score(self, obs_dict): + raise NotImplementedError() + + def step(self, a, b=None): + a = np.clip(a, -1.0, 1.0) + + if not self.initializing: + a = self.act_mid + a * self.act_amp # mean center and scale + else: + self.goal = self._get_task_goal() # update goal if init + + self.robot.step( + self, a, step_duration=self.skip * self.model.opt.timestep) + + # observations + obs = self._get_obs() + + #rewards + reward_dict, score = self._get_reward_n_score(self.obs_dict) + + # termination + done = False + + # finalize step + env_info = { + 'time': self.obs_dict['t'], + 'obs_dict': self.obs_dict, + 'rewards': reward_dict, + 'score': score, + 'images': np.asarray(self.render(mode='rgb_array')) + } + # self.render() + return obs, reward_dict['r_total'], done, env_info + + def _get_obs(self): + t, qp, qv, obj_qp, obj_qv = self.robot.get_obs( + self, robot_noise_ratio=self.robot_noise_ratio) + + self.obs_dict = {} + self.obs_dict['t'] = t + self.obs_dict['qp'] = qp + self.obs_dict['qv'] = qv + self.obs_dict['obj_qp'] = obj_qp + self.obs_dict['obj_qv'] = obj_qv + self.obs_dict['goal'] = self.goal + if self.goal_concat: + return np.concatenate([self.obs_dict['qp'], self.obs_dict['obj_qp'], self.obs_dict['goal']]) + + def reset_model(self): + reset_pos = self.init_qpos[:].copy() + reset_vel = self.init_qvel[:].copy() + self.robot.reset(self, reset_pos, reset_vel) + self.sim.forward() + self.goal = self._get_task_goal() #sample a new goal on reset + return self._get_obs() + + def evaluate_success(self, paths): + # score + mean_score_per_rollout = np.zeros(shape=len(paths)) + for idx, path in enumerate(paths): + mean_score_per_rollout[idx] = np.mean(path['env_infos']['score']) + mean_score = np.mean(mean_score_per_rollout) + + # success percentage + num_success = 0 + num_paths = len(paths) + for path in paths: + num_success += bool(path['env_infos']['rewards']['bonus'][-1]) + success_percentage = num_success * 100.0 / num_paths + + # fuse results + return np.sign(mean_score) * ( + 1e6 * round(success_percentage, 2) + abs(mean_score)) + + def close_env(self): + self.robot.close() + + def set_goal(self, goal): + self.goal = goal + + def _get_task_goal(self): + return self.goal + + # Only include goal + @property + def goal_space(self): + len_obs = self.observation_space.low.shape[0] + env_lim = np.abs(self.observation_space.low[0]) + return spaces.Box(low=-env_lim, high=env_lim, shape=(len_obs//2,)) + + def convert_to_active_observation(self, observation): + return observation + +class KitchenTaskRelaxV1(KitchenV0): + """Kitchen environment with proper camera and goal setup""" + + def __init__(self): + super(KitchenTaskRelaxV1, self).__init__() + + def _get_reward_n_score(self, obs_dict): + reward_dict = {} + reward_dict['true_reward'] = 0. + reward_dict['bonus'] = 0. + reward_dict['r_total'] = 0. + score = 0. 
+ return reward_dict, score + + def render(self, mode='human'): + if mode =='rgb_array': + camera = engine.MovableCamera(self.sim, 1920, 2560) + camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35) + img = camera.render() + return img + else: + super(KitchenTaskRelaxV1, self).render() diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py b/d4rl/d4rl/kitchen/adept_envs/franka/robot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml b/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml new file mode 100644 index 0000000..aeb4f49 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_config.xml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py b/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py new file mode 100644 index 0000000..4bbe868 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/franka/robot/franka_robot.py @@ -0,0 +1,265 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, getpass +import numpy as np +from termcolor import cprint +import time +import copy +import click + +from d4rl.kitchen.adept_envs import base_robot +from d4rl.kitchen.adept_envs.utils.config import (get_config_root_node, read_config_from_node) + +# obervations structure +from collections import namedtuple +observation = namedtuple('observation', ['time', 'qpos_robot', 'qvel_robot', 'qpos_object', 'qvel_object']) + + + +franka_interface = '' + +class Robot(base_robot.BaseRobot): + + """ + Abstracts away the differences between the robot_simulation and robot_hardware + + """ + + def __init__(self, *args, **kwargs): + super(Robot, self).__init__(*args, **kwargs) + global franka_interface + + # Read robot configurations + self._read_specs_from_config(robot_configs=self.calibration_path) + + + # Robot: Handware + if self.is_hardware: + + if franka_interface is '': + raise NotImplementedError() + from handware.franka import franka + + # initialize franka + self.franka_interface = franka() + franka_interface = self.franka_interface + cprint("Initializing %s Hardware (Status:%d)" % (self.robot_name, self.franka.okay(self.robot_hardware_dof)), 'white', 'on_grey') + else: + self.franka_interface = franka_interface + cprint("Reusing previours Franka session", 'white', 'on_grey') + + # Robot: Simulation + else: + self.robot_name = "Franka" + cprint("Initializing %s sim" % self.robot_name, 'white', 'on_grey') + + # Robot's time + self.time_start = time.time() + self.time = time.time()-self.time_start + self.time_render = -1 # time of rendering + + + # read specs from the calibration file + def _read_specs_from_config(self, robot_configs): + root, root_name = get_config_root_node(config_file_name=robot_configs) + self.robot_name = root_name[0] + self.robot_mode = 
np.zeros(self.n_dofs, dtype=int) + self.robot_mj_dof = np.zeros(self.n_dofs, dtype=int) + self.robot_hardware_dof = np.zeros(self.n_dofs, dtype=int) + self.robot_scale = np.zeros(self.n_dofs, dtype=float) + self.robot_offset = np.zeros(self.n_dofs, dtype=float) + self.robot_pos_bound = np.zeros([self.n_dofs, 2], dtype=float) + self.robot_vel_bound = np.zeros([self.n_dofs, 2], dtype=float) + self.robot_pos_noise_amp = np.zeros(self.n_dofs, dtype=float) + self.robot_vel_noise_amp = np.zeros(self.n_dofs, dtype=float) + + print("Reading configurations for %s" % self.robot_name) + for i in range(self.n_dofs): + self.robot_mode[i] = read_config_from_node(root, "qpos"+str(i), "mode", int) + self.robot_mj_dof[i] = read_config_from_node(root, "qpos"+str(i), "mj_dof", int) + self.robot_hardware_dof[i] = read_config_from_node(root, "qpos"+str(i), "hardware_dof", int) + self.robot_scale[i] = read_config_from_node(root, "qpos"+str(i), "scale", float) + self.robot_offset[i] = read_config_from_node(root, "qpos"+str(i), "offset", float) + self.robot_pos_bound[i] = read_config_from_node(root, "qpos"+str(i), "pos_bound", float) + self.robot_vel_bound[i] = read_config_from_node(root, "qpos"+str(i), "vel_bound", float) + self.robot_pos_noise_amp[i] = read_config_from_node(root, "qpos"+str(i), "pos_noise_amp", float) + self.robot_vel_noise_amp[i] = read_config_from_node(root, "qpos"+str(i), "vel_noise_amp", float) + + + # convert to hardware space + def _de_calib(self, qp_mj, qv_mj=None): + qp_ad = (qp_mj-self.robot_offset)/self.robot_scale + if qv_mj is not None: + qv_ad = qv_mj/self.robot_scale + return qp_ad, qv_ad + else: + return qp_ad + + # convert to mujoco space + def _calib(self, qp_ad, qv_ad): + qp_mj = qp_ad* self.robot_scale + self.robot_offset + qv_mj = qv_ad* self.robot_scale + return qp_mj, qv_mj + + + # refresh the observation cache + def _observation_cache_refresh(self, env): + for _ in range(self.observation_cache_maxsize): + self.get_obs(env, sim_mimic_hardware=False) + + # get past observation + def get_obs_from_cache(self, env, index=-1): + assert (index>=0 and index=-self.observation_cache_maxsize), \ + "cache index out of bound. 
(cache size is %2d)"%self.observation_cache_maxsize + obs = self.observation_cache[index] + if self.has_obj: + return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object + else: + return obs.time, obs.qpos_robot, obs.qvel_robot + + + # get observation + def get_obs(self, env, robot_noise_ratio=1, object_noise_ratio=1, sim_mimic_hardware=True): + if self.is_hardware: + raise NotImplementedError() + + else: + #Gather simulated observation + qp = env.sim.data.qpos[:self.n_jnt].copy() + qv = env.sim.data.qvel[:self.n_jnt].copy() + if self.has_obj: + qp_obj = env.sim.data.qpos[-self.n_obj:].copy() + qv_obj = env.sim.data.qvel[-self.n_obj:].copy() + else: + qp_obj = None + qv_obj = None + self.time = env.sim.data.time + + # Simulate observation noise + if not env.initializing: + qp += robot_noise_ratio*self.robot_pos_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt) + qv += robot_noise_ratio*self.robot_vel_noise_amp[:self.n_jnt]*env.np_random.uniform(low=-1., high=1., size=self.n_jnt) + if self.has_obj: + qp_obj += robot_noise_ratio*self.robot_pos_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj) + qv_obj += robot_noise_ratio*self.robot_vel_noise_amp[-self.n_obj:]*env.np_random.uniform(low=-1., high=1., size=self.n_obj) + + # cache observations + obs = observation(time=self.time, qpos_robot=qp, qvel_robot=qv, qpos_object=qp_obj, qvel_object=qv_obj) + self.observation_cache.append(obs) + + if self.has_obj: + return obs.time, obs.qpos_robot, obs.qvel_robot, obs.qpos_object, obs.qvel_object + else: + return obs.time, obs.qpos_robot, obs.qvel_robot + + + # enforce position specs. + def ctrl_position_limits(self, ctrl_position): + ctrl_feasible_position = np.clip(ctrl_position, self.robot_pos_bound[:self.n_jnt, 0], self.robot_pos_bound[:self.n_jnt, 1]) + return ctrl_feasible_position + + + # step the robot env + def step(self, env, ctrl_desired, step_duration, sim_override=False): + + # Populate observation cache during startup + if env.initializing: + self._observation_cache_refresh(env) + + # enforce velocity limits + ctrl_feasible = self.ctrl_velocity_limits(ctrl_desired, step_duration) + + # enforce position limits + ctrl_feasible = self.ctrl_position_limits(ctrl_feasible) + + # Send controls to the robot + if self.is_hardware and (not sim_override): + raise NotImplementedError() + else: + env.do_simulation(ctrl_feasible, int(step_duration/env.sim.model.opt.timestep)) # render is folded in here + + # Update current robot state on the overlay + if self.overlay: + env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose.copy() + env.sim.forward() + + # synchronize time + if self.is_hardware: + time_now = (time.time()-self.time_start) + time_left_in_step = step_duration - (time_now-self.time) + if(time_left_in_step>0.0001): + time.sleep(time_left_in_step) + return 1 + + + def reset(self, env, reset_pose, reset_vel, overlay_mimic_reset_pose=True, sim_override=False): + reset_pose = self.clip_positions(reset_pose) + + if self.is_hardware: + raise NotImplementedError() + else: + env.sim.reset() + env.sim.data.qpos[:self.n_jnt] = reset_pose[:self.n_jnt].copy() + env.sim.data.qvel[:self.n_jnt] = reset_vel[:self.n_jnt].copy() + if self.has_obj: + env.sim.data.qpos[-self.n_obj:] = reset_pose[-self.n_obj:].copy() + env.sim.data.qvel[-self.n_obj:] = reset_vel[-self.n_obj:].copy() + env.sim.forward() + + if self.overlay: + env.sim.data.qpos[self.n_jnt:2*self.n_jnt] = env.desired_pose[:self.n_jnt].copy() + 
env.sim.forward() + + # refresh observation cache before exit + self._observation_cache_refresh(env) + + + def close(self): + if self.is_hardware: + cprint("Closing Franka hardware... ", 'white', 'on_grey', end='', flush=True) + status = 0 + raise NotImplementedError() + cprint("Closed (Status: {})".format(status), 'white', 'on_grey', flush=True) + else: + cprint("Closing Franka sim", 'white', 'on_grey', flush=True) + + +class Robot_PosAct(Robot): + + # enforce velocity sepcs. + # ALERT: This depends on previous observation. This is not ideal as it breaks MDP addumptions. Be careful + def ctrl_velocity_limits(self, ctrl_position, step_duration): + last_obs = self.observation_cache[-1] + ctrl_desired_vel = (ctrl_position-last_obs.qpos_robot[:self.n_jnt])/step_duration + + ctrl_feasible_vel = np.clip(ctrl_desired_vel, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1]) + ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration + return ctrl_feasible_position + + +class Robot_VelAct(Robot): + + # enforce velocity sepcs. + # ALERT: This depends on previous observation. This is not ideal as it breaks MDP addumptions. Be careful + def ctrl_velocity_limits(self, ctrl_velocity, step_duration): + last_obs = self.observation_cache[-1] + + ctrl_feasible_vel = np.clip(ctrl_velocity, self.robot_vel_bound[:self.n_jnt, 0], self.robot_vel_bound[:self.n_jnt, 1]) + ctrl_feasible_position = last_obs.qpos_robot[:self.n_jnt] + ctrl_feasible_vel*step_duration + return ctrl_feasible_position + diff --git a/d4rl/d4rl/kitchen/adept_envs/mujoco_env.py b/d4rl/d4rl/kitchen/adept_envs/mujoco_env.py new file mode 100644 index 0000000..237cff1 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/mujoco_env.py @@ -0,0 +1,200 @@ +"""Base environment for MuJoCo-based environments.""" + +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import os +import time +from typing import Dict, Optional + +import gym +from gym import spaces +from gym.utils import seeding +import numpy as np + +from d4rl.kitchen.adept_envs.simulation.sim_robot import MujocoSimRobot, RenderMode + +DEFAULT_RENDER_SIZE = 480 + +USE_DM_CONTROL = True + + +class MujocoEnv(gym.Env): + """Superclass for all MuJoCo environments.""" + + def __init__(self, + model_path: str, + frame_skip: int, + camera_settings: Optional[Dict] = None, + use_dm_backend: Optional[bool] = None, + ): + """Initializes a new MuJoCo environment. + + Args: + model_path: The path to the MuJoCo XML file. + frame_skip: The number of simulation steps per environment step. On + hardware this influences the duration of each environment step. + camera_settings: Settings to initialize the simulation camera. This + can contain the keys `distance`, `azimuth`, and `elevation`. + use_dm_backend: A boolean to switch between mujoco-py and dm_control. 
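Robot_VelAct.ctrl_velocity_limits above turns a commanded joint velocity into a feasible position target by clipping against the per-joint velocity bounds and integrating over the step duration; a small numeric sketch with made-up numbers:

```python
# Numeric sketch of the Robot_VelAct velocity-limiting logic above.
# Bounds, last position, commanded velocity and step duration are all made up.
import numpy as np

vel_bound = np.array([[-2.0, 2.0], [-1.0, 1.0]])   # per-joint (min, max) velocity
last_qpos = np.array([0.10, -0.20])
ctrl_velocity = np.array([3.0, -0.5])              # requested joint velocities
step_duration = 0.08                               # ~ frame_skip * timestep

feasible_vel = np.clip(ctrl_velocity, vel_bound[:, 0], vel_bound[:, 1])
position_target = last_qpos + feasible_vel * step_duration
print(position_target)                             # [ 0.26 -0.24]
```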
+ """ + self._seed() + if not os.path.isfile(model_path): + raise IOError( + '[MujocoEnv]: Model path does not exist: {}'.format(model_path)) + self.frame_skip = frame_skip + + self.sim_robot = MujocoSimRobot( + model_path, + use_dm_backend=use_dm_backend or USE_DM_CONTROL, + camera_settings=camera_settings) + self.sim = self.sim_robot.sim + self.model = self.sim_robot.model + self.data = self.sim_robot.data + + self.metadata = { + 'render.modes': ['human', 'rgb_array', 'depth_array'], + 'video.frames_per_second': int(np.round(1.0 / self.dt)) + } + self.mujoco_render_frames = False + + self.init_qpos = self.data.qpos.ravel().copy() + self.init_qvel = self.data.qvel.ravel().copy() + observation, _reward, done, _info = self.step(np.zeros(self.model.nu)) + assert not done + + bounds = self.model.actuator_ctrlrange.copy() + act_upper = bounds[:, 1] + act_lower = bounds[:, 0] + + # Define the action and observation spaces. + # HACK: MJRL is still using gym 0.9.x so we can't provide a dtype. + try: + self.action_space = spaces.Box( + act_lower, act_upper, dtype=np.float32) + if isinstance(observation, collections.Mapping): + self.observation_space = spaces.Dict({ + k: spaces.Box(-np.inf, np.inf, shape=v.shape, dtype=np.float32) for k, v in observation.items()}) + else: + self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size + self.observation_space = spaces.Box( + -np.inf, np.inf, observation.shape, dtype=np.float32) + + except TypeError: + # Fallback case for gym 0.9.x + self.action_space = spaces.Box(act_lower, act_upper) + assert not isinstance(observation, collections.Mapping), 'gym 0.9.x does not support dictionary observation.' + self.obs_dim = np.sum([o.size for o in observation]) if type(observation) is tuple else observation.size + self.observation_space = spaces.Box( + -np.inf, np.inf, observation.shape) + + def seed(self, seed=None): # Compatibility with new gym + return self._seed(seed) + + def _seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + return [seed] + + # methods to override: + # ---------------------------- + + def reset_model(self): + """Reset the robot degrees of freedom (qpos and qvel). + + Implement this in each subclass. + """ + raise NotImplementedError + + # ----------------------------- + + def reset(self): # compatibility with new gym + return self._reset() + + def _reset(self): + self.sim.reset() + self.sim.forward() + ob = self.reset_model() + return ob + + def set_state(self, qpos, qvel): + assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) + state = self.sim.get_state() + for i in range(self.model.nq): + state.qpos[i] = qpos[i] + for i in range(self.model.nv): + state.qvel[i] = qvel[i] + self.sim.set_state(state) + self.sim.forward() + + @property + def dt(self): + return self.model.opt.timestep * self.frame_skip + + def do_simulation(self, ctrl, n_frames): + for i in range(self.model.nu): + self.sim.data.ctrl[i] = ctrl[i] + + for _ in range(n_frames): + self.sim.step() + + # TODO(michaelahn): Remove this; render should be called separately. + if self.mujoco_render_frames is True: + self.mj_render() + + def render(self, + mode='human', + width=DEFAULT_RENDER_SIZE, + height=DEFAULT_RENDER_SIZE, + camera_id=-1): + """Renders the environment. + + Args: + mode: The type of rendering to use. + - 'human': Renders to a graphical window. + - 'rgb_array': Returns the RGB image as an np.ndarray. + - 'depth_array': Returns the depth image as an np.ndarray. 
+ width: The width of the rendered image. This only affects offscreen + rendering. + height: The height of the rendered image. This only affects + offscreen rendering. + camera_id: The ID of the camera to use. By default, this is the free + camera. If specified, only affects offscreen rendering. + """ + if mode == 'human': + self.sim_robot.renderer.render_to_window() + elif mode == 'rgb_array': + assert width and height + return self.sim_robot.renderer.render_offscreen( + width, height, mode=RenderMode.RGB, camera_id=camera_id) + elif mode == 'depth_array': + assert width and height + return self.sim_robot.renderer.render_offscreen( + width, height, mode=RenderMode.DEPTH, camera_id=camera_id) + else: + raise NotImplementedError(mode) + + def close(self): + self.sim_robot.close() + + def mj_render(self): + """Backwards compatibility with MJRL.""" + self.render(mode='human') + + def state_vector(self): + state = self.sim.get_state() + return np.concatenate([state.qpos.flat, state.qvel.flat]) diff --git a/d4rl/d4rl/kitchen/adept_envs/robot_env.py b/d4rl/d4rl/kitchen/adept_envs/robot_env.py new file mode 100644 index 0000000..4d6e75d --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/robot_env.py @@ -0,0 +1,166 @@ +"""Base class for robotics environments.""" + +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from typing import Dict, Optional + +import numpy as np + + +from d4rl.kitchen.adept_envs import mujoco_env +from d4rl.kitchen.adept_envs.base_robot import BaseRobot +from d4rl.kitchen.adept_envs.utils.configurable import import_class_from_path +from d4rl.kitchen.adept_envs.utils.constants import MODELS_PATH + + +class RobotEnv(mujoco_env.MujocoEnv): + """Base environment for all adept robots.""" + + # Mapping of robot name to fully qualified class path. + # e.g. 'robot': 'adept_envs.dclaw.robot.Robot' + # Subclasses should override this to specify the Robot classes they support. + ROBOTS = {} + + # Mapping of device path to the calibration file to use. If the device path + # is not found, the 'default' key is used. + # This can be overriden by subclasses. + CALIBRATION_PATHS = {} + + def __init__(self, + model_path: str, + robot: BaseRobot, + frame_skip: int, + camera_settings: Optional[Dict] = None): + """Initializes a robotics environment. + + Args: + model_path: The path to the model to run. Relative paths will be + interpreted as relative to the 'adept_models' folder. + robot: The Robot object to use. + frame_skip: The number of simulation steps per environment step. On + hardware this influences the duration of each environment step. + camera_settings: Settings to initialize the simulation camera. This + can contain the keys `distance`, `azimuth`, and `elevation`. + """ + self._robot = robot + + # Initial pose for first step. 
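# Sketch of how the three render modes above are typically called (assumes
# `env` is a constructed MujocoEnv subclass; shapes are illustrative):
#   env.render(mode='human')                                        # interactive window
#   rgb   = env.render(mode='rgb_array',   width=256, height=256)   # (256, 256, 3) array
#   depth = env.render(mode='depth_array', width=256, height=256)   # (256, 256) array
# width, height and camera_id only affect the offscreen modes; 'human'
# always draws to the viewer window.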
+ self.desired_pose = np.zeros(self.n_jnt) + + if not model_path.startswith('/'): + model_path = os.path.abspath(os.path.join(MODELS_PATH, model_path)) + + self.remote_viz = None + + try: + from adept_envs.utils.remote_viz import RemoteViz + self.remote_viz = RemoteViz(model_path) + except ImportError: + pass + + + self._initializing = True + super(RobotEnv, self).__init__( + model_path, frame_skip, camera_settings=camera_settings) + self._initializing = False + + + @property + def robot(self): + return self._robot + + @property + def n_jnt(self): + return self._robot.n_jnt + + @property + def n_obj(self): + return self._robot.n_obj + + @property + def skip(self): + """Alias for frame_skip. Needed for MJRL.""" + return self.frame_skip + + @property + def initializing(self): + return self._initializing + + def close_env(self): + if self._robot is not None: + self._robot.close() + + def make_robot(self, + n_jnt, + n_obj=0, + is_hardware=False, + device_name=None, + legacy=False, + **kwargs): + """Creates a new robot for the environment. + + Args: + n_jnt: The number of joints in the robot. + n_obj: The number of object joints in the robot environment. + is_hardware: Whether to run on hardware or not. + device_name: The device path for the robot hardware. + legacy: If true, runs using direct dynamixel communication rather + than DDS. + kwargs: See BaseRobot for other parameters. + + Returns: + A Robot object. + """ + if not self.ROBOTS: + raise NotImplementedError('Subclasses must override ROBOTS.') + + if is_hardware and not device_name: + raise ValueError('Must provide device name if running on hardware.') + + robot_name = 'dds_robot' if not legacy and is_hardware else 'robot' + if robot_name not in self.ROBOTS: + raise KeyError("Unsupported robot '{}', available: {}".format( + robot_name, list(self.ROBOTS.keys()))) + + cls = import_class_from_path(self.ROBOTS[robot_name]) + + calibration_path = None + if self.CALIBRATION_PATHS: + if not device_name: + calibration_name = 'default' + elif device_name not in self.CALIBRATION_PATHS: + print('Device "{}" not in CALIBRATION_PATHS; using default.' + .format(device_name)) + calibration_name = 'default' + else: + calibration_name = device_name + + calibration_path = self.CALIBRATION_PATHS[calibration_name] + if not os.path.isfile(calibration_path): + raise OSError('Could not find calibration file at: {}'.format( + calibration_path)) + + return cls( + n_jnt, + n_obj, + is_hardware=is_hardware, + device_name=device_name, + calibration_path=calibration_path, + **kwargs) diff --git a/d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py b/d4rl/d4rl/kitchen/adept_envs/simulation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/kitchen/adept_envs/simulation/module.py b/d4rl/d4rl/kitchen/adept_envs/simulation/module.py new file mode 100644 index 0000000..a1284c7 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/simulation/module.py @@ -0,0 +1,126 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
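# Sketch of how a RobotEnv subclass is expected to wire up the make_robot()
# factory above. The class path, calibration path, model file and joint
# counts below are hypothetical placeholders, not values from this diff.
# class ExampleKitchenEnv(RobotEnv):
#     ROBOTS = {'robot': 'my_package.robot:Robot'}
#     CALIBRATION_PATHS = {'default': '/path/to/default_calibration.xml'}
#
#     def __init__(self):
#         robot = self.make_robot(n_jnt=9, n_obj=21, is_hardware=False)
#         super().__init__('kitchen/kitchen.xml', robot, frame_skip=40)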
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for caching Python modules related to simulation.""" + +import sys + +_MUJOCO_PY_MODULE = None + +_DM_MUJOCO_MODULE = None +_DM_VIEWER_MODULE = None +_DM_RENDER_MODULE = None + +_GLFW_MODULE = None + + +def get_mujoco_py(): + """Returns the mujoco_py module.""" + global _MUJOCO_PY_MODULE + if _MUJOCO_PY_MODULE: + return _MUJOCO_PY_MODULE + try: + import mujoco_py + # Override the warning function. + from mujoco_py.builder import cymj + cymj.set_warning_callback(_mj_warning_fn) + except ImportError: + print( + 'Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo ' + 'v1.50) is installed.', + file=sys.stderr) + sys.exit(1) + _MUJOCO_PY_MODULE = mujoco_py + return mujoco_py + + +def get_mujoco_py_mjlib(): + """Returns the mujoco_py mjlib module.""" + + class MjlibDelegate: + """Wrapper that forwards mjlib calls.""" + + def __init__(self, lib): + self._lib = lib + + def __getattr__(self, name: str): + if name.startswith('mj'): + return getattr(self._lib, '_' + name) + raise AttributeError(name) + + return MjlibDelegate(get_mujoco_py().cymj) + + +def get_dm_mujoco(): + """Returns the DM Control mujoco module.""" + global _DM_MUJOCO_MODULE + if _DM_MUJOCO_MODULE: + return _DM_MUJOCO_MODULE + try: + from dm_control import mujoco + except ImportError: + print( + 'Failed to import dm_control.mujoco. Ensure that dm_control (using ' + 'MuJoCo v2.00) is installed.', + file=sys.stderr) + sys.exit(1) + _DM_MUJOCO_MODULE = mujoco + return mujoco + + +def get_dm_viewer(): + """Returns the DM Control viewer module.""" + global _DM_VIEWER_MODULE + if _DM_VIEWER_MODULE: + return _DM_VIEWER_MODULE + try: + from dm_control import viewer + except ImportError: + print( + 'Failed to import dm_control.viewer. Ensure that dm_control (using ' + 'MuJoCo v2.00) is installed.', + file=sys.stderr) + sys.exit(1) + _DM_VIEWER_MODULE = viewer + return viewer + + +def get_dm_render(): + """Returns the DM Control render module.""" + global _DM_RENDER_MODULE + if _DM_RENDER_MODULE: + return _DM_RENDER_MODULE + try: + try: + from dm_control import _render + render = _render + except ImportError: + print('Warning: DM Control is out of date.') + from dm_control import render + except ImportError: + print( + 'Failed to import dm_control.render. Ensure that dm_control (using ' + 'MuJoCo v2.00) is installed.', + file=sys.stderr) + sys.exit(1) + _DM_RENDER_MODULE = render + return render + + +def _mj_warning_fn(warn_data: bytes): + """Warning function override for mujoco_py.""" + print('WARNING: Mujoco simulation is unstable (has NaNs): {}'.format( + warn_data.decode())) diff --git a/d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py b/d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py new file mode 100644 index 0000000..758864a --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/simulation/renderer.py @@ -0,0 +1,293 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
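# The get_* helpers above all follow the same lazy-import-and-cache pattern;
# a generic, self-contained sketch of that pattern (illustrative, not part
# of the source) looks like this:
import importlib
import sys

_MODULE_CACHE = {}

def lazy_import(name, install_hint):
    # Import `name` at most once, cache the module, and fail with a readable hint.
    if name in _MODULE_CACHE:
        return _MODULE_CACHE[name]
    try:
        module = importlib.import_module(name)
    except ImportError:
        print('Failed to import {}. {}'.format(name, install_hint), file=sys.stderr)
        raise
    _MODULE_CACHE[name] = module
    return module

# e.g. np = lazy_import('numpy', 'Install it with `pip install numpy`.')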
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for viewing Physics objects in the DM Control viewer.""" + +import abc +import enum +import sys +from typing import Dict, Optional + +import numpy as np + +from d4rl.kitchen.adept_envs.simulation import module + +# Default window dimensions. +DEFAULT_WINDOW_WIDTH = 1024 +DEFAULT_WINDOW_HEIGHT = 768 + +DEFAULT_WINDOW_TITLE = 'MuJoCo Viewer' + +_MAX_RENDERBUFFER_SIZE = 2048 + + +class RenderMode(enum.Enum): + """Rendering modes for offscreen rendering.""" + RGB = 0 + DEPTH = 1 + SEGMENTATION = 2 + + +class Renderer(abc.ABC): + """Base interface for rendering simulations.""" + + def __init__(self, camera_settings: Optional[Dict] = None): + self._camera_settings = camera_settings + + @abc.abstractmethod + def close(self): + """Cleans up any resources being used by the renderer.""" + + @abc.abstractmethod + def render_to_window(self): + """Renders the simulation to a window.""" + + @abc.abstractmethod + def render_offscreen(self, + width: int, + height: int, + mode: RenderMode = RenderMode.RGB, + camera_id: int = -1) -> np.ndarray: + """Renders the camera view as a NumPy array of pixels. + + Args: + width: The viewport width (pixels). + height: The viewport height (pixels). + mode: The rendering mode. + camera_id: The ID of the camera to render from. By default, uses + the free camera. + + Returns: + A NumPy array of the pixels. + """ + + def _update_camera(self, camera): + """Updates the given camera to move to the initial settings.""" + if not self._camera_settings: + return + distance = self._camera_settings.get('distance') + azimuth = self._camera_settings.get('azimuth') + elevation = self._camera_settings.get('elevation') + lookat = self._camera_settings.get('lookat') + + if distance is not None: + camera.distance = distance + if azimuth is not None: + camera.azimuth = azimuth + if elevation is not None: + camera.elevation = elevation + if lookat is not None: + camera.lookat[:] = lookat + + +class MjPyRenderer(Renderer): + """Class for rendering mujoco_py simulations.""" + + def __init__(self, sim, **kwargs): + assert isinstance(sim, module.get_mujoco_py().MjSim), \ + 'MjPyRenderer takes a mujoco_py MjSim object.' + super().__init__(**kwargs) + self._sim = sim + self._onscreen_renderer = None + self._offscreen_renderer = None + + def render_to_window(self): + """Renders the simulation to a window.""" + if not self._onscreen_renderer: + self._onscreen_renderer = module.get_mujoco_py().MjViewer(self._sim) + self._update_camera(self._onscreen_renderer.cam) + + self._onscreen_renderer.render() + + def render_offscreen(self, + width: int, + height: int, + mode: RenderMode = RenderMode.RGB, + camera_id: int = -1) -> np.ndarray: + """Renders the camera view as a NumPy array of pixels. + + Args: + width: The viewport width (pixels). + height: The viewport height (pixels). + mode: The rendering mode. + camera_id: The ID of the camera to render from. By default, uses + the free camera. + + Returns: + A NumPy array of the pixels. + """ + if not self._offscreen_renderer: + self._offscreen_renderer = module.get_mujoco_py() \ + .MjRenderContextOffscreen(self._sim) + + # Update the camera configuration for the free-camera. 
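# The camera_settings dict consumed by _update_camera() above recognises the
# keys shown here (values are illustrative); any key that is omitted keeps
# the viewer's default.
# camera_settings = {
#     'distance': 2.5,             # metres from the lookat point
#     'azimuth': 90.0,             # rotation about the vertical axis, degrees
#     'elevation': -30.0,          # degrees below the horizontal
#     'lookat': [0.0, 0.5, 1.0],   # world-frame point the camera tracks
# }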
+ if camera_id == -1: + self._update_camera(self._offscreen_renderer.cam) + + self._offscreen_renderer.render(width, height, camera_id) + if mode == RenderMode.RGB: + data = self._offscreen_renderer.read_pixels( + width, height, depth=False) + # Original image is upside-down, so flip it + return data[::-1, :, :] + elif mode == RenderMode.DEPTH: + data = self._offscreen_renderer.read_pixels( + width, height, depth=True)[1] + # Original image is upside-down, so flip it + return data[::-1, :] + else: + raise NotImplementedError(mode) + + def close(self): + """Cleans up any resources being used by the renderer.""" + + +class DMRenderer(Renderer): + """Class for rendering DM Control Physics objects.""" + + def __init__(self, physics, **kwargs): + assert isinstance(physics, module.get_dm_mujoco().Physics), \ + 'DMRenderer takes a DM Control Physics object.' + super().__init__(**kwargs) + self._physics = physics + self._window = None + + # Set the camera to lookat the center of the geoms. (mujoco_py does + # this automatically. + if 'lookat' not in self._camera_settings: + self._camera_settings['lookat'] = [ + np.median(self._physics.data.geom_xpos[:, i]) for i in range(3) + ] + + def render_to_window(self): + """Renders the Physics object to a window. + + The window continuously renders the Physics in a separate thread. + + This function is a no-op if the window was already created. + """ + if not self._window: + self._window = DMRenderWindow() + self._window.load_model(self._physics) + self._update_camera(self._window.camera) + self._window.run_frame() + + def render_offscreen(self, + width: int, + height: int, + mode: RenderMode = RenderMode.RGB, + camera_id: int = -1) -> np.ndarray: + """Renders the camera view as a NumPy array of pixels. + + Args: + width: The viewport width (pixels). + height: The viewport height (pixels). + mode: The rendering mode. + camera_id: The ID of the camera to render from. By default, uses + the free camera. + + Returns: + A NumPy array of the pixels. + """ + mujoco = module.get_dm_mujoco() + # TODO(michaelahn): Consider caching the camera. + camera = mujoco.Camera( + physics=self._physics, + height=height, + width=width, + camera_id=camera_id) + + # Update the camera configuration for the free-camera. + if camera_id == -1: + self._update_camera( + camera._render_camera, # pylint: disable=protected-access + ) + + image = camera.render( + depth=(mode == RenderMode.DEPTH), + segmentation=(mode == RenderMode.SEGMENTATION)) + camera._scene.free() # pylint: disable=protected-access + return image + + def close(self): + """Cleans up any resources being used by the renderer.""" + if self._window: + self._window.close() + self._window = None + + +class DMRenderWindow: + """Class that encapsulates a graphical window.""" + + def __init__(self, + width: int = DEFAULT_WINDOW_WIDTH, + height: int = DEFAULT_WINDOW_HEIGHT, + title: str = DEFAULT_WINDOW_TITLE): + """Creates a graphical render window. + + Args: + width: The width of the window. + height: The height of the window. + title: The title of the window. 
+ """ + dmv = module.get_dm_viewer() + self._viewport = dmv.renderer.Viewport(width, height) + self._window = dmv.gui.RenderWindow(width, height, title) + self._viewer = dmv.viewer.Viewer(self._viewport, self._window.mouse, + self._window.keyboard) + self._draw_surface = None + self._renderer = dmv.renderer.NullRenderer() + + @property + def camera(self): + return self._viewer._camera._camera + + def close(self): + self._viewer.deinitialize() + self._renderer.release() + self._draw_surface.free() + self._window.close() + + def load_model(self, physics): + """Loads the given Physics object to render.""" + self._viewer.deinitialize() + + self._draw_surface = module.get_dm_render().Renderer( + max_width=_MAX_RENDERBUFFER_SIZE, max_height=_MAX_RENDERBUFFER_SIZE) + self._renderer = module.get_dm_viewer().renderer.OffScreenRenderer( + physics.model, self._draw_surface) + + self._viewer.initialize(physics, self._renderer, touchpad=False) + + def run_frame(self): + """Renders one frame of the simulation. + + NOTE: This is extremely slow at the moment. + """ + glfw = module.get_dm_viewer().gui.glfw_gui.glfw + glfw_window = self._window._context.window + if glfw.window_should_close(glfw_window): + sys.exit(0) + + self._viewport.set_size(*self._window.shape) + self._viewer.render() + pixels = self._renderer.pixels + + with self._window._context.make_current() as ctx: + ctx.call(self._window._update_gui_on_render_thread, glfw_window, + pixels) + self._window._mouse.process_events() + self._window._keyboard.process_events() diff --git a/d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py b/d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py new file mode 100644 index 0000000..d319b7e --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/simulation/sim_robot.py @@ -0,0 +1,135 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for loading MuJoCo models.""" + +import os +from typing import Dict, Optional + +from d4rl.kitchen.adept_envs.simulation import module +from d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer, RenderMode + + +class MujocoSimRobot: + """Class that encapsulates a MuJoCo simulation. + + This class exposes methods that are agnostic to the simulation backend. + Two backends are supported: + 1. mujoco_py - MuJoCo v1.50 + 2. dm_control - MuJoCo v2.00 + """ + + def __init__(self, + model_file: str, + use_dm_backend: bool = False, + camera_settings: Optional[Dict] = None): + """Initializes a new simulation. + + Args: + model_file: The MuJoCo XML model file to load. + use_dm_backend: If True, uses DM Control's Physics (MuJoCo v2.0) as + the backend for the simulation. Otherwise, uses mujoco_py (MuJoCo + v1.5) as the backend. + camera_settings: Settings to initialize the renderer's camera. This + can contain the keys `distance`, `azimuth`, and `elevation`. 
+ """ + self._use_dm_backend = use_dm_backend + + if not os.path.isfile(model_file): + raise ValueError( + '[MujocoSimRobot] Invalid model file path: {}'.format( + model_file)) + + if self._use_dm_backend: + dm_mujoco = module.get_dm_mujoco() + if model_file.endswith('.mjb'): + self.sim = dm_mujoco.Physics.from_binary_path(model_file) + else: + self.sim = dm_mujoco.Physics.from_xml_path(model_file) + self.model = self.sim.model + self._patch_mjlib_accessors(self.model, self.sim.data) + self.renderer = DMRenderer( + self.sim, camera_settings=camera_settings) + else: # Use mujoco_py + mujoco_py = module.get_mujoco_py() + self.model = mujoco_py.load_model_from_path(model_file) + self.sim = mujoco_py.MjSim(self.model) + self.renderer = MjPyRenderer( + self.sim, camera_settings=camera_settings) + + self.data = self.sim.data + + def close(self): + """Cleans up any resources being used by the simulation.""" + self.renderer.close() + + def save_binary(self, path: str): + """Saves the loaded model to a binary .mjb file.""" + if os.path.exists(path): + raise ValueError( + '[MujocoSimRobot] Path already exists: {}'.format(path)) + if not path.endswith('.mjb'): + path = path + '.mjb' + if self._use_dm_backend: + self.model.save_binary(path) + else: + with open(path, 'wb') as f: + f.write(self.model.get_mjb()) + + def get_mjlib(self): + """Returns an object that exposes the low-level MuJoCo API.""" + if self._use_dm_backend: + return module.get_dm_mujoco().wrapper.mjbindings.mjlib + else: + return module.get_mujoco_py_mjlib() + + def _patch_mjlib_accessors(self, model, data): + """Adds accessors to the DM Control objects to support mujoco_py API.""" + assert self._use_dm_backend + mjlib = self.get_mjlib() + + def name2id(type_name, name): + obj_id = mjlib.mj_name2id(model.ptr, + mjlib.mju_str2Type(type_name.encode()), + name.encode()) + if obj_id < 0: + raise ValueError('No {} with name "{}" exists.'.format( + type_name, name)) + return obj_id + + if not hasattr(model, 'body_name2id'): + model.body_name2id = lambda name: name2id('body', name) + + if not hasattr(model, 'geom_name2id'): + model.geom_name2id = lambda name: name2id('geom', name) + + if not hasattr(model, 'site_name2id'): + model.site_name2id = lambda name: name2id('site', name) + + if not hasattr(model, 'joint_name2id'): + model.joint_name2id = lambda name: name2id('joint', name) + + if not hasattr(model, 'actuator_name2id'): + model.actuator_name2id = lambda name: name2id('actuator', name) + + if not hasattr(model, 'camera_name2id'): + model.camera_name2id = lambda name: name2id('camera', name) + + if not hasattr(data, 'body_xpos'): + data.body_xpos = data.xpos + + if not hasattr(data, 'body_xquat'): + data.body_xquat = data.xquat diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/__init__.py b/d4rl/d4rl/kitchen/adept_envs/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/config.py b/d4rl/d4rl/kitchen/adept_envs/utils/config.py new file mode 100644 index 0000000..e6aea13 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/utils/config.py @@ -0,0 +1,99 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +try: + import cElementTree as ET +except ImportError: + try: + # Python 2.5 need to import a different module + import xml.etree.cElementTree as ET + except ImportError: + exit_err("Failed to import cElementTree from any known place") + +CONFIG_XML_DATA = """ + + + + + +""" + + +# Read config from root +def read_config_from_node(root_node, parent_name, child_name, dtype=int): + # find parent + parent_node = root_node.find(parent_name) + if parent_node == None: + quit("Parent %s not found" % parent_name) + + # get child data + child_data = parent_node.get(child_name) + if child_data == None: + quit("Child %s not found" % child_name) + + config_val = np.array(child_data.split(), dtype=dtype) + return config_val + + +# get config frlom file or string +def get_config_root_node(config_file_name=None, config_file_data=None): + try: + # get root + if config_file_data is None: + config_file_content = open(config_file_name, "r") + config = ET.parse(config_file_content) + root_node = config.getroot() + else: + root_node = ET.fromstring(config_file_data) + + # get root data + root_data = root_node.get('name') + root_name = np.array(root_data.split(), dtype=str) + except: + quit("ERROR: Unable to process config file %s" % config_file_name) + + return root_node, root_name + + +# Read config from config_file +def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int): + root_node, root_name = get_config_root_node( + config_file_name=config_file_name) + return read_config_from_node(root_node, parent_name, child_name, dtype) + + +# tests +if __name__ == '__main__': + print("Read config and parse -------------------------") + root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA) + print("Root:name \t", root_name) + print("limit:low \t", read_config_from_node(root, "limits", "low", float)) + print("limit:high \t", read_config_from_node(root, "limits", "high", float)) + print("scale:joint \t", read_config_from_node(root, "scale", "joint", + float)) + print("data:type \t", read_config_from_node(root, "data", "type", str)) + + # read straight from xml (dum the XML data as duh.xml for this test) + root, root_name = get_config_root_node(config_file_name="duh.xml") + print("Read from xml --------------------------------") + print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low", + float)) + print("limit:high \t", + read_config_from_xml("duh.xml", "limits", "high", float)) + print("scale:joint \t", + read_config_from_xml("duh.xml", "scale", "joint", float)) + print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str)) diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/configurable.py b/d4rl/d4rl/kitchen/adept_envs/utils/configurable.py new file mode 100644 index 0000000..6685a50 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/utils/configurable.py @@ -0,0 +1,163 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
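# A self-contained sketch of what read_config_from_node() above does, using
# an illustrative XML string in place of CONFIG_XML_DATA (whose markup is
# not reproduced here):
import numpy as np
import xml.etree.ElementTree as ET

xml_data = '<root name="example"><limits low="-1 -2 -3" high="1 2 3"/></root>'
root = ET.fromstring(xml_data)
low = np.array(root.find('limits').get('low').split(), dtype=float)
high = np.array(root.find('limits').get('high').split(), dtype=float)
print(low, high)   # -> [-1. -2. -3.] [1. 2. 3.]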
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os + +from gym.envs.registration import registry as gym_registry + + +def import_class_from_path(class_path): + """Given 'path.to.module:object', imports and returns the object.""" + module_path, class_name = class_path.split(":") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +class ConfigCache(object): + """Configuration class to store constructor arguments. + + This is used to store parameters to pass to Gym environments at init time. + """ + + def __init__(self): + self._configs = {} + self._default_config = {} + + def set_default_config(self, config): + """Sets the default configuration used for all RobotEnv envs.""" + self._default_config = dict(config) + + def set_config(self, cls_or_env_id, config): + """Sets the configuration for the given environment within a context. + + Args: + cls_or_env_id (Class | str): A class type or Gym environment ID to + configure. + config (dict): The configuration parameters. + """ + config_key = self._get_config_key(cls_or_env_id) + self._configs[config_key] = dict(config) + + def get_config(self, cls_or_env_id): + """Returns the configuration for the given env name. + + Args: + cls_or_env_id (Class | str): A class type or Gym environment ID to + get the configuration of. + """ + config_key = self._get_config_key(cls_or_env_id) + config = dict(self._default_config) + config.update(self._configs.get(config_key, {})) + return config + + def clear_config(self, cls_or_env_id): + """Clears the configuration for the given ID.""" + config_key = self._get_config_key(cls_or_env_id) + if config_key in self._configs: + del self._configs[config_key] + + def _get_config_key(self, cls_or_env_id): + if inspect.isclass(cls_or_env_id): + return cls_or_env_id + env_id = cls_or_env_id + assert isinstance(env_id, str) + if env_id not in gym_registry.env_specs: + raise ValueError("Unregistered environment name {}.".format(env_id)) + entry_point = gym_registry.env_specs[env_id]._entry_point + if callable(entry_point): + return entry_point + else: + return import_class_from_path(entry_point) + + +# Global robot config. +global_config = ConfigCache() + + +def configurable(config_id=None, pickleable=False, config_cache=global_config): + """Class decorator to allow injection of constructor arguments. + + This allows constructor arguments to be passed via ConfigCache. + Example usage: + + @configurable() + class A: + def __init__(b=None, c=2, d='Wow'): + ... + + global_config.set_config(A, {'b': 10, 'c': 20}) + a = A() # b=10, c=20, d='Wow' + a = A(b=30) # b=30, c=20, d='Wow' + + Args: + config_id: ID of the config to use. This defaults to the class type. + pickleable: Whether this class is pickleable. If true, causes the pickle + state to include the config and constructor arguments. + config_cache: The ConfigCache to use to read config data from. Uses + the global ConfigCache by default. + """ + def cls_decorator(cls): + assert inspect.isclass(cls) + + # Overwrite the class constructor to pass arguments from the config. 
+ base_init = cls.__init__ + def __init__(self, *args, **kwargs): + + config = config_cache.get_config(config_id or type(self)) + # Allow kwargs to override the config. + kwargs = {**config, **kwargs} + + # print('Initializing {} with params: {}'.format(type(self).__name__, + # kwargs)) + + if pickleable: + self._pkl_env_args = args + self._pkl_env_kwargs = kwargs + + base_init(self, *args, **kwargs) + cls.__init__ = __init__ + + # If the class is pickleable, overwrite the state methods to save + # the constructor arguments and config. + if pickleable: + # Use same pickle keys as gym.utils.ezpickle for backwards compat. + PKL_ARGS_KEY = '_ezpickle_args' + PKL_KWARGS_KEY = '_ezpickle_kwargs' + + def __getstate__(self): + return { + PKL_ARGS_KEY: self._pkl_env_args, + PKL_KWARGS_KEY: self._pkl_env_kwargs, + } + cls.__getstate__ = __getstate__ + + def __setstate__(self, data): + saved_args = data[PKL_ARGS_KEY] + saved_kwargs = data[PKL_KWARGS_KEY] + + # Override the saved state with the current config. + config = config_cache.get_config(config_id or type(self)) + # Allow kwargs to override the config. + kwargs = {**saved_kwargs, **config} + + inst = type(self)(*saved_args, **kwargs) + self.__dict__.update(inst.__dict__) + cls.__setstate__ = __setstate__ + + return cls + return cls_decorator diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/constants.py b/d4rl/d4rl/kitchen/adept_envs/utils/constants.py new file mode 100644 index 0000000..9c63fb7 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/utils/constants.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +ENVS_ROOT_PATH = os.path.abspath(os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../../")) + +MODELS_PATH = os.path.abspath(os.path.join(ENVS_ROOT_PATH, "../adept_models/")) diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py b/d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py new file mode 100644 index 0000000..01f9c36 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/utils/parse_demos.py @@ -0,0 +1,221 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
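# A tiny sketch of the precedence rule implemented by the configurable
# wrapper above: values from the ConfigCache fill in constructor arguments,
# and explicit kwargs passed by the caller still win. Dict values are
# illustrative.
config_from_cache = {'b': 10, 'c': 20}   # as if set via global_config.set_config(...)
caller_kwargs = {'b': 30}
merged = {**config_from_cache, **caller_kwargs}
print(merged)   # -> {'b': 30, 'c': 20}; any parameter in neither dict keeps
                # its default from the class's __init__ signature.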
+ +import click +import glob +import pickle +import numpy as np +from parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs +from mjrl.utils.gym_env import GymEnv +import adept_envs +import time as timer +import skvideo.io +import gym + +# headless renderer +render_buffer = [] # rendering buffer + + +def viewer(env, + mode='initialize', + filename='video', + frame_size=(640, 480), + camera_id=0, + render=None): + if render == 'onscreen': + env.mj_render() + + elif render == 'offscreen': + + global render_buffer + if mode == 'initialize': + render_buffer = [] + mode = 'render' + + if mode == 'render': + curr_frame = env.render(mode='rgb_array') + render_buffer.append(curr_frame) + + if mode == 'save': + skvideo.io.vwrite(filename, np.asarray(render_buffer)) + print("\noffscreen buffer saved", filename) + + elif render == 'None': + pass + + else: + print("unknown render: ", render) + + +# view demos (physics ignored) +def render_demos(env, data, filename='demo_rendering.mp4', render=None): + FPS = 30 + render_skip = max(1, round(1. / \ + (FPS * env.sim.model.opt.timestep * env.frame_skip))) + t0 = timer.time() + + viewer(env, mode='initialize', render=render) + for i_frame in range(data['ctrl'].shape[0]): + env.sim.data.qpos[:] = data['qpos'][i_frame].copy() + env.sim.data.qvel[:] = data['qvel'][i_frame].copy() + env.sim.forward() + if i_frame % render_skip == 0: + viewer(env, mode='render', render=render) + print(i_frame, end=', ', flush=True) + + viewer(env, mode='save', filename=filename, render=render) + print("time taken = %f" % (timer.time() - t0)) + + +# playback demos and get data(physics respected) +def gather_training_data(env, data, filename='demo_playback.mp4', render=None): + env = env.env + FPS = 30 + render_skip = max(1, round(1. / \ + (FPS * env.sim.model.opt.timestep * env.frame_skip))) + t0 = timer.time() + + # initialize + env.reset() + init_qpos = data['qpos'][0].copy() + init_qvel = data['qvel'][0].copy() + act_mid = env.act_mid + act_rng = env.act_amp + + # prepare env + env.sim.data.qpos[:] = init_qpos + env.sim.data.qvel[:] = init_qvel + env.sim.forward() + viewer(env, mode='initialize', render=render) + + # step the env and gather data + path_obs = None + for i_frame in range(data['ctrl'].shape[0] - 1): + # Reset every time step + # if i_frame % 1 == 0: + # qp = data['qpos'][i_frame].copy() + # qv = data['qvel'][i_frame].copy() + # env.sim.data.qpos[:] = qp + # env.sim.data.qvel[:] = qv + # env.sim.forward() + + obs = env._get_obs() + + # Construct the action + # ctrl = (data['qpos'][i_frame + 1][:9] - obs[:9]) / (env.skip * env.model.opt.timestep) + ctrl = (data['ctrl'][i_frame] - obs[:9])/(env.skip*env.model.opt.timestep) + act = (ctrl - act_mid) / act_rng + act = np.clip(act, -0.999, 0.999) + next_obs, reward, done, env_info = env.step(act) + if path_obs is None: + path_obs = obs + path_act = act + else: + path_obs = np.vstack((path_obs, obs)) + path_act = np.vstack((path_act, act)) + + # render when needed to maintain FPS + if i_frame % render_skip == 0: + viewer(env, mode='render', render=render) + print(i_frame, end=', ', flush=True) + + # finalize + if render: + viewer(env, mode='save', filename=filename, render=render) + + t1 = timer.time() + print("time taken = %f" % (t1 - t0)) + + # note that are one step away from + return path_obs, path_act, init_qpos, init_qvel + + +# MAIN ========================================================= +@click.command(help="parse tele-op demos") +@click.option('--env', '-e', type=str, help='gym env name', required=True) 
+@click.option( + '--demo_dir', + '-d', + type=str, + help='directory with tele-op logs', + required=True) +@click.option( + '--skip', + '-s', + type=int, + help='number of frames to skip (1:no skip)', + default=1) +@click.option('--graph', '-g', type=bool, help='plot logs', default=False) +@click.option('--save_logs', '-l', type=bool, help='save logs', default=False) +@click.option( + '--view', '-v', type=str, help='render/playback', default='render') +@click.option( + '--render', '-r', type=str, help='onscreen/offscreen', default='onscreen') +def main(env, demo_dir, skip, graph, save_logs, view, render): + + gym_env = gym.make(env) + paths = [] + print("Scanning demo_dir: " + demo_dir + "=========") + for ind, file in enumerate(glob.glob(demo_dir + "*.mjl")): + + # process logs + print("processing: " + file, end=': ') + + data = parse_mjl_logs(file, skip) + + print("log duration %0.2f" % (data['time'][-1] - data['time'][0])) + + # plot logs + if (graph): + print("plotting: " + file) + viz_parsed_mjl_logs(data) + + # save logs + if (save_logs): + pickle.dump(data, open(file[:-4] + ".pkl", 'wb')) + + # render logs to video + if view == 'render': + render_demos( + gym_env, + data, + filename=data['logName'][:-4] + '_demo_render.mp4', + render=render) + + # playback logs and gather data + elif view == 'playback': + try: + obs, act,init_qpos, init_qvel = gather_training_data(gym_env, data,\ + filename=data['logName'][:-4]+'_playback.mp4', render=render) + except Exception as e: + print(e) + continue + path = { + 'observations': obs, + 'actions': act, + 'goals': obs, + 'init_qpos': init_qpos, + 'init_qvel': init_qvel + } + paths.append(path) + # accept = input('accept demo?') + # if accept == 'n': + # continue + pickle.dump(path, open(demo_dir + env + str(ind) + "_path.pkl", 'wb')) + print(demo_dir + env + file + "_path.pkl") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py b/d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py new file mode 100644 index 0000000..bae531a --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_envs/utils/quatmath.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
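# Sketch of how the parse_demos entry point above is typically invoked from
# a shell; the environment id and directory are placeholders:
#   python parse_demos.py --env kitchen_relax-v1 --demo_dir /path/to/demos/ \
#       --view playback --render offscreen
# '--view render' replays the logged qpos/qvel frames with physics ignored
# and writes a video, while '--view playback' re-steps the env with
# reconstructed actions and pickles the resulting observation/action paths
# next to the logs.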
+ +import numpy as np +# For testing whether a number is close to zero +_FLOAT_EPS = np.finfo(np.float64).eps +_EPS4 = _FLOAT_EPS * 4.0 + + +def mulQuat(qa, qb): + res = np.zeros(4) + res[0] = qa[0]*qb[0] - qa[1]*qb[1] - qa[2]*qb[2] - qa[3]*qb[3] + res[1] = qa[0]*qb[1] + qa[1]*qb[0] + qa[2]*qb[3] - qa[3]*qb[2] + res[2] = qa[0]*qb[2] - qa[1]*qb[3] + qa[2]*qb[0] + qa[3]*qb[1] + res[3] = qa[0]*qb[3] + qa[1]*qb[2] - qa[2]*qb[1] + qa[3]*qb[0] + return res + +def negQuat(quat): + return np.array([quat[0], -quat[1], -quat[2], -quat[3]]) + +def quat2Vel(quat, dt=1): + axis = quat[1:].copy() + sin_a_2 = np.sqrt(np.sum(axis**2)) + axis = axis/(sin_a_2+1e-8) + speed = 2*np.arctan2(sin_a_2, quat[0])/dt + return speed, axis + +def quatDiff2Vel(quat1, quat2, dt): + neg = negQuat(quat1) + diff = mulQuat(quat2, neg) + return quat2Vel(diff, dt) + + +def axis_angle2quat(axis, angle): + c = np.cos(angle/2) + s = np.sin(angle/2) + return np.array([c, s*axis[0], s*axis[1], s*axis[2]]) + +def euler2mat(euler): + """ Convert Euler Angles to Rotation Matrix. See rotation.py for notes """ + euler = np.asarray(euler, dtype=np.float64) + assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler) + + ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0] + si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) + ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) + cc, cs = ci * ck, ci * sk + sc, ss = si * ck, si * sk + + mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64) + mat[..., 2, 2] = cj * ck + mat[..., 2, 1] = sj * sc - cs + mat[..., 2, 0] = sj * cc + ss + mat[..., 1, 2] = cj * sk + mat[..., 1, 1] = sj * ss + cc + mat[..., 1, 0] = sj * cs - sc + mat[..., 0, 2] = -sj + mat[..., 0, 1] = cj * si + mat[..., 0, 0] = cj * ci + return mat + + +def euler2quat(euler): + """ Convert Euler Angles to Quaternions. See rotation.py for notes """ + euler = np.asarray(euler, dtype=np.float64) + assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler) + + ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2 + si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) + ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) + cc, cs = ci * ck, ci * sk + sc, ss = si * ck, si * sk + + quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64) + quat[..., 0] = cj * cc + sj * ss + quat[..., 3] = cj * sc - sj * cs + quat[..., 2] = -(cj * ss + sj * cc) + quat[..., 1] = cj * cs - sj * sc + return quat + + +def mat2euler(mat): + """ Convert Rotation Matrix to Euler Angles. See rotation.py for notes """ + mat = np.asarray(mat, dtype=np.float64) + assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) + + cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2]) + condition = cy > _EPS4 + euler = np.empty(mat.shape[:-1], dtype=np.float64) + euler[..., 2] = np.where(condition, + -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), + -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1])) + euler[..., 1] = np.where(condition, + -np.arctan2(-mat[..., 0, 2], cy), + -np.arctan2(-mat[..., 0, 2], cy)) + euler[..., 0] = np.where(condition, + -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), + 0.0) + return euler + + +def mat2quat(mat): + """ Convert Rotation Matrix to Quaternion. 
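# A few worked values for the quaternion helpers above (w, x, y, z ordering,
# as used throughout this file); numbers rounded to 4 decimal places:
#   euler2quat(np.zeros(3))            -> [1, 0, 0, 0]            (identity)
#   euler2quat([np.pi / 2, 0, 0])      -> [0.7071, 0.7071, 0, 0]  (90 deg about x)
#   axis_angle2quat([0, 0, 1], np.pi)  -> [0, 0, 0, 1]            (180 deg about z)
#   mulQuat(q, negQuat(q))             -> [1, 0, 0, 0] for any unit quaternion q,
#                                         since negQuat() returns the conjugate.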
See rotation.py for notes """ + mat = np.asarray(mat, dtype=np.float64) + assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) + + Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2] + Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2] + Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2] + # Fill only lower half of symmetric matrix + K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64) + K[..., 0, 0] = Qxx - Qyy - Qzz + K[..., 1, 0] = Qyx + Qxy + K[..., 1, 1] = Qyy - Qxx - Qzz + K[..., 2, 0] = Qzx + Qxz + K[..., 2, 1] = Qzy + Qyz + K[..., 2, 2] = Qzz - Qxx - Qyy + K[..., 3, 0] = Qyz - Qzy + K[..., 3, 1] = Qzx - Qxz + K[..., 3, 2] = Qxy - Qyx + K[..., 3, 3] = Qxx + Qyy + Qzz + K /= 3.0 + # TODO: vectorize this -- probably could be made faster + q = np.empty(K.shape[:-2] + (4,)) + it = np.nditer(q[..., 0], flags=['multi_index']) + while not it.finished: + # Use Hermitian eigenvectors, values for speed + vals, vecs = np.linalg.eigh(K[it.multi_index]) + # Select largest eigenvector, reorder to w,x,y,z quaternion + q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)] + # Prefer quaternion with positive w + # (q * -1 corresponds to same rotation as q) + if q[it.multi_index][0] < 0: + q[it.multi_index] *= -1 + it.iternext() + return q + + +def quat2euler(quat): + """ Convert Quaternion to Euler Angles. See rotation.py for notes """ + return mat2euler(quat2mat(quat)) + + +def quat2mat(quat): + """ Convert Quaternion to Euler Angles. See rotation.py for notes """ + quat = np.asarray(quat, dtype=np.float64) + assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat) + + w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + Nq = np.sum(quat * quat, axis=-1) + s = 2.0 / Nq + X, Y, Z = x * s, y * s, z * s + wX, wY, wZ = w * X, w * Y, w * Z + xX, xY, xZ = x * X, x * Y, x * Z + yY, yZ, zZ = y * Y, y * Z, z * Z + + mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64) + mat[..., 0, 0] = 1.0 - (yY + zZ) + mat[..., 0, 1] = xY - wZ + mat[..., 0, 2] = xZ + wY + mat[..., 1, 0] = xY + wZ + mat[..., 1, 1] = 1.0 - (xX + zZ) + mat[..., 1, 2] = yZ - wX + mat[..., 2, 0] = xZ - wY + mat[..., 2, 1] = yZ + wX + mat[..., 2, 2] = 1.0 - (xX + yY) + return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3)) \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/adept_models/.gitignore b/d4rl/d4rl/kitchen/adept_models/.gitignore new file mode 100644 index 0000000..b8e8667 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/.gitignore @@ -0,0 +1,8 @@ +# General +.DS_Store +*.swp +*.profraw + +# Editors +.vscode +.idea diff --git a/d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md b/d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md new file mode 100644 index 0000000..db177d4 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/CONTRIBUTING.public.md @@ -0,0 +1,28 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. 
+ +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/d4rl/d4rl/kitchen/adept_models/LICENSE b/d4rl/d4rl/kitchen/adept_models/LICENSE new file mode 100644 index 0000000..9a644b9 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/LICENSE @@ -0,0 +1,203 @@ +Copyright 2019 The DSuite Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/d4rl/d4rl/kitchen/adept_models/README.public.md b/d4rl/d4rl/kitchen/adept_models/README.public.md new file mode 100644 index 0000000..da3fa5d --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/README.public.md @@ -0,0 +1,10 @@ +# D'Suite Scenes + +This repository is based on a collection of [MuJoCo](http://www.mujoco.org/) simulation +scenes and common assets for D'Suite environments. Based on code in the ROBEL suite +https://github.com/google-research/robel + +## Disclaimer + +This is not an official Google product. + diff --git a/d4rl/d4rl/kitchen/adept_models/__init__.py b/d4rl/d4rl/kitchen/adept_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml new file mode 100644 index 0000000..9e1e39d --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_asset.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml new file mode 100644 index 0000000..b76b0da --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/backwall_chain.xml @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml new file mode 100644 index 0000000..c3e28f8 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_asset.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml new file mode 100644 index 0000000..83e1791 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/counters_chain.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml new file mode 100644 index 0000000..8202810 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_asset.xml @@ -0,0 +1,25 
@@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml new file mode 100644 index 0000000..7f935d3 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/hingecabinet_chain.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml new file mode 100644 index 0000000..dbe8e9b --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_asset.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml new file mode 100644 index 0000000..fb5f224 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/kettle_chain.xml @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml new file mode 100644 index 0000000..cc651ee --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_asset.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml new file mode 100644 index 0000000..fd88ab3 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/microwave_chain.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml new file mode 100644 index 0000000..ef1184e --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_asset.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml new file mode 100644 index 0000000..f96f8c7 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/oven_chain.xml @@ -0,0 +1,115 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml new file mode 100644 index 0000000..f0f370a --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_asset.xml @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml new file mode 100644 index 0000000..5aa820e --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/assets/slidecabinet_chain.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml new file mode 100644 
index 0000000..69fb889 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/counters.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml new file mode 100644 index 0000000..89b8db4 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/hingecabinet.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml new file mode 100644 index 0000000..a27e978 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/kettle.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml new file mode 100644 index 0000000..34813ca --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/kitchen.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate.stl new file mode 100644 index 0000000..46740b5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate_mesh.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate_mesh.stl new file mode 100644 index 0000000..46740b5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/burnerplate_mesh.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetbase.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetbase.stl new file mode 100644 index 0000000..580a51c Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetbase.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetdrawer.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetdrawer.stl new file mode 100644 index 0000000..0932eeb Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinetdrawer.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinethandle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinethandle.stl new file mode 100644 index 0000000..960cd39 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/cabinethandle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/countertop.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/countertop.stl new file mode 100644 index 0000000..16410d1 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/countertop.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/faucet.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/faucet.stl new file mode 100644 index 0000000..55404af Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/faucet.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/handle2.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/handle2.stl new file mode 100644 index 0000000..09b7833 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/handle2.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingecabinet.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingecabinet.stl new file mode 100644 index 0000000..6693df8 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingecabinet.stl 
differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingedoor.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingedoor.stl new file mode 100644 index 0000000..feecf23 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingedoor.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingehandle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingehandle.stl new file mode 100644 index 0000000..fb85521 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hingehandle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hood.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hood.stl new file mode 100644 index 0000000..6c0e3ad Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/hood.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettle.stl new file mode 100644 index 0000000..0e8d9e5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettlehandle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettlehandle.stl new file mode 100644 index 0000000..83baef3 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/kettlehandle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/knob.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/knob.stl new file mode 100644 index 0000000..90180b5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/knob.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitch.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitch.stl new file mode 100644 index 0000000..fa956c9 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitch.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitchbase.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitchbase.stl new file mode 100644 index 0000000..e64b059 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/lightswitchbase.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/micro.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/micro.stl new file mode 100644 index 0000000..6ed6802 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/micro.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microbutton.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microbutton.stl new file mode 100644 index 0000000..2d7f1e3 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microbutton.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microdoor.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microdoor.stl new file mode 100644 index 0000000..fa8c548 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microdoor.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microefeet.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microefeet.stl new file mode 100644 index 0000000..98e7069 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microefeet.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microfeet.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microfeet.stl new file mode 100644 index 0000000..a516299 Binary files /dev/null and 
b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microfeet.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microhandle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microhandle.stl new file mode 100644 index 0000000..ed31a70 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microhandle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microwindow.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microwindow.stl new file mode 100644 index 0000000..07d3c85 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/microwindow.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oven.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oven.stl new file mode 100644 index 0000000..04d3b66 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oven.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenhandle.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenhandle.stl new file mode 100644 index 0000000..30250a7 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenhandle.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oventop.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oventop.stl new file mode 100644 index 0000000..fb6664d Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/oventop.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenwindow.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenwindow.stl new file mode 100644 index 0000000..f0205a5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/ovenwindow.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidecabinet.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidecabinet.stl new file mode 100644 index 0000000..6249a14 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidecabinet.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidedoor.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidedoor.stl new file mode 100644 index 0000000..307d6c5 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/slidedoor.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/stoverim.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/stoverim.stl new file mode 100644 index 0000000..0f76bfc Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/stoverim.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/tile.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/tile.stl new file mode 100644 index 0000000..12639ce Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/tile.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/wall.stl b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/wall.stl new file mode 100644 index 0000000..f5562e2 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/meshes/wall.stl differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml new file mode 100644 index 0000000..3946632 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/microwave.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml new file mode 100644 index 0000000..6891385 --- /dev/null +++ 
b/d4rl/d4rl/kitchen/adept_models/kitchen/oven.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml b/d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml new file mode 100644 index 0000000..78fa599 --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/kitchen/slidecabinet.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/textures/marble1.png b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/marble1.png new file mode 100644 index 0000000..a72c67e Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/marble1.png differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/textures/metal1.png b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/metal1.png new file mode 100644 index 0000000..f16a314 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/metal1.png differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/textures/tile1.png b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/tile1.png new file mode 100644 index 0000000..3e859b4 Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/tile1.png differ diff --git a/d4rl/d4rl/kitchen/adept_models/kitchen/textures/wood1.png b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/wood1.png new file mode 100644 index 0000000..8d2b69b Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/kitchen/textures/wood1.png differ diff --git a/d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml b/d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml new file mode 100644 index 0000000..8d5356d --- /dev/null +++ b/d4rl/d4rl/kitchen/adept_models/scenes/basic_scene.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile.png b/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile.png new file mode 100644 index 0000000..c3f397a Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile.png differ diff --git a/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile2.png b/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile2.png new file mode 100644 index 0000000..00033fc Binary files /dev/null and b/d4rl/d4rl/kitchen/adept_models/scenes/textures/white_marble_tile2.png differ diff --git a/d4rl/d4rl/kitchen/kitchen_envs.py b/d4rl/d4rl/kitchen/kitchen_envs.py new file mode 100644 index 0000000..6888f13 --- /dev/null +++ b/d4rl/d4rl/kitchen/kitchen_envs.py @@ -0,0 +1,98 @@ +"""Environments using kitchen and Franka robot.""" +import os +import numpy as np +from d4rl.kitchen.adept_envs.utils.configurable import configurable +from d4rl.kitchen.adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1 + +from d4rl.offline_env import OfflineEnv + +OBS_ELEMENT_INDICES = { + 'bottom burner': np.array([11, 12]), + 'top burner': np.array([15, 16]), + 'light switch': np.array([17, 18]), + 'slide cabinet': np.array([19]), + 'hinge cabinet': np.array([20, 21]), + 'microwave': np.array([22]), + 'kettle': np.array([23, 24, 25, 26, 27, 28, 29]), + } +OBS_ELEMENT_GOALS = { + 'bottom burner': np.array([-0.88, -0.01]), + 'top burner': np.array([-0.92, -0.01]), + 'light switch': np.array([-0.69, -0.05]), + 'slide cabinet': np.array([0.37]), + 'hinge cabinet': np.array([0., 1.45]), + 'microwave': np.array([-0.75]), + 'kettle': np.array([-0.23, 0.75, 1.62, 0.99, 0., 0., -0.06]), + } +BONUS_THRESH = 0.3 + 
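+# Note on the tables above (descriptive comments, added for clarity): OBS_ELEMENT_INDICES maps each kitchen element to the indices of its joints in the concatenated robot+object position vector (the robot joints come first, so the reward code below subtracts len(obs_dict['qp']) as an offset), and OBS_ELEMENT_GOALS gives the joint values at which that element counts as solved.
+# An element is treated as complete once the L2 distance between its observed joint values and its goal drops below BONUS_THRESH; each newly completed element adds +1 to the sparse bonus and is removed from tasks_to_complete.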
+@configurable(pickleable=True) +class KitchenBase(KitchenTaskRelaxV1, OfflineEnv): + # A string of element names. The robot's task is then to modify each of + # these elements appropriately. + TASK_ELEMENTS = [] + REMOVE_TASKS_WHEN_COMPLETE = True + TERMINATE_ON_TASK_COMPLETE = True + + def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, **kwargs): + self.tasks_to_complete = set(self.TASK_ELEMENTS) + super(KitchenBase, self).__init__(**kwargs) + OfflineEnv.__init__( + self, + dataset_url=dataset_url, + ref_max_score=ref_max_score, + ref_min_score=ref_min_score) + + def _get_task_goal(self): + new_goal = np.zeros_like(self.goal) + for element in self.TASK_ELEMENTS: + element_idx = OBS_ELEMENT_INDICES[element] + element_goal = OBS_ELEMENT_GOALS[element] + new_goal[element_idx] = element_goal + + return new_goal + + def reset_model(self): + self.tasks_to_complete = set(self.TASK_ELEMENTS) + return super(KitchenBase, self).reset_model() + + def _get_reward_n_score(self, obs_dict): + reward_dict, score = super(KitchenBase, self)._get_reward_n_score(obs_dict) + reward = 0. + next_q_obs = obs_dict['qp'] + next_obj_obs = obs_dict['obj_qp'] + next_goal = obs_dict['goal'] + idx_offset = len(next_q_obs) + completions = [] + for element in self.tasks_to_complete: + element_idx = OBS_ELEMENT_INDICES[element] + distance = np.linalg.norm( + next_obj_obs[..., element_idx - idx_offset] - + next_goal[element_idx]) + complete = distance < BONUS_THRESH + if complete: + completions.append(element) + if self.REMOVE_TASKS_WHEN_COMPLETE: + [self.tasks_to_complete.remove(element) for element in completions] + bonus = float(len(completions)) + reward_dict['bonus'] = bonus + reward_dict['r_total'] = bonus + score = bonus + return reward_dict, score + + def step(self, a, b=None): + obs, reward, done, env_info = super(KitchenBase, self).step(a, b=b) + if self.TERMINATE_ON_TASK_COMPLETE: + done = not self.tasks_to_complete + return obs, reward, done, env_info + + def render(self, mode='human'): + # Disable rendering to speed up environment evaluation. + return [] + + +class KitchenMicrowaveKettleLightSliderV0(KitchenBase): + TASK_ELEMENTS = ['microwave', 'kettle', 'light switch', 'slide cabinet'] + +class KitchenMicrowaveKettleBottomBurnerLightV0(KitchenBase): + TASK_ELEMENTS = ['microwave', 'kettle', 'bottom burner', 'light switch'] diff --git a/d4rl/d4rl/kitchen/third_party/franka/LICENSE b/d4rl/d4rl/kitchen/third_party/franka/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
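As a minimal usage sketch (assuming this d4rl fork and mujoco_py are installed as in the Dockerfile), the kitchen classes defined in kitchen_envs.py above can be driven like ordinary gym environments; the random-action loop and the 100-step budget below are placeholders for illustration, not part of the patch and not an evaluation protocol.

from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleLightSliderV0

env = KitchenMicrowaveKettleLightSliderV0()
obs = env.reset()
for _ in range(100):
    action = env.action_space.sample()  # stand-in for a trained policy
    obs, reward, done, info = env.step(action)
    # With the sparse bonus defined above, reward counts the task elements
    # that newly come within BONUS_THRESH of their goal configurations.
    if done:  # TERMINATE_ON_TASK_COMPLETE: all four task elements solved
        break
print('tasks remaining:', env.tasks_to_complete)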
diff --git a/d4rl/d4rl/kitchen/third_party/franka/README.md b/d4rl/d4rl/kitchen/third_party/franka/README.md new file mode 100644 index 0000000..d96eaf1 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/README.md @@ -0,0 +1,9 @@ +# franka +Franka panda mujoco models + + +# Environment + +franka_panda.xml | comming soon +:-------------------------:|:-------------------------: +![Alt text](franka_panda.png?raw=false "sawyer") | comming soon diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml new file mode 100644 index 0000000..86ee47c --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/actuator0.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml new file mode 100644 index 0000000..a8eda4e --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/actuator1.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml new file mode 100644 index 0000000..4f2cded --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/assets.xml @@ -0,0 +1,63 @@ + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml new file mode 100644 index 0000000..4bb7e70 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/basic_scene.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml new file mode 100644 index 0000000..e2e53a7 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/chain0.xml @@ -0,0 +1,103 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml new file mode 100644 index 0000000..e64f497 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/chain0_overlay.xml @@ -0,0 +1,62 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml new file mode 100644 index 0000000..29a9524 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/chain1.xml @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml b/d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml new file mode 100644 index 0000000..e5e46db --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/assets/teleop_actuator.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml b/d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml new file mode 100644 index 0000000..c307269 --- /dev/null +++ 
b/d4rl/d4rl/kitchen/third_party/franka/bi-franka_panda.xml @@ -0,0 +1,81 @@ + + + + + + + + + + + + + + + + + / + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/franka_panda.png b/d4rl/d4rl/kitchen/third_party/franka/franka_panda.png new file mode 100644 index 0000000..c34bec0 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/franka_panda.png differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml b/d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml new file mode 100644 index 0000000..07c5193 --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/franka_panda.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml b/d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml new file mode 100644 index 0000000..cdbf8cd --- /dev/null +++ b/d4rl/d4rl/kitchen/third_party/franka/franka_panda_teleop.xml @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/finger.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/finger.stl new file mode 100644 index 0000000..3b87289 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/finger.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/hand.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/hand.stl new file mode 100644 index 0000000..4e82090 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/hand.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link0.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link0.stl new file mode 100644 index 0000000..def070c Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link0.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link1.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link1.stl new file mode 100644 index 0000000..426bcf2 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link1.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link2.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link2.stl new file mode 100644 index 0000000..b369f15 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link2.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link3.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link3.stl new file mode 100644 index 0000000..25162ee Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link3.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link4.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link4.stl new file mode 100644 index 0000000..76c8c33 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link4.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link5.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link5.stl new file mode 100644 index 0000000..3006a0b Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link5.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link6.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link6.stl new file mode 100644 
index 0000000..2e9594a Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link6.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link7.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link7.stl new file mode 100644 index 0000000..0532d05 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/collision/link7.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/finger.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/finger.stl new file mode 100644 index 0000000..2a5a256 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/finger.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/hand.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/hand.stl new file mode 100644 index 0000000..9ecd7f2 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/hand.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link0.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link0.stl new file mode 100644 index 0000000..bf71a18 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link0.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link1.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link1.stl new file mode 100644 index 0000000..6289e56 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link1.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link2.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link2.stl new file mode 100644 index 0000000..5580a80 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link2.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link3.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link3.stl new file mode 100644 index 0000000..cdbe281 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link3.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link4.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link4.stl new file mode 100644 index 0000000..df43017 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link4.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link5.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link5.stl new file mode 100644 index 0000000..9cb5360 Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link5.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link6.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link6.stl new file mode 100644 index 0000000..d43652f Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link6.stl differ diff --git a/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link7.stl b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link7.stl new file mode 100644 index 0000000..6d369ed Binary files /dev/null and b/d4rl/d4rl/kitchen/third_party/franka/meshes/visual/link7.stl differ diff --git a/d4rl/d4rl/locomotion/__init__.py b/d4rl/d4rl/locomotion/__init__.py new file mode 100644 index 0000000..e87d097 --- /dev/null +++ b/d4rl/d4rl/locomotion/__init__.py @@ -0,0 +1,424 @@ +from gym.envs.registration import register +from d4rl.locomotion import ant +from d4rl.locomotion import maze_env + +""" +register( + id='antmaze-umaze-v0', + 
entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) +""" + +register( + id='antmaze-umaze-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-umaze-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-medium-play-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-medium-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-large-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-large-play-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-umaze-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'deprecated': 
True, + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_False_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-umaze-diverse-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_umaze_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-medium-play-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-medium-diverse-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_medium_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-large-diverse-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-large-play-v1', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'deprecated': True, + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v1/Ant_maze_large_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-umaze-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'maze_map': maze_env.U_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-umaze-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'maze_map': maze_env.U_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_umaze_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-medium-play-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.BIG_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-medium-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.BIG_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_medium_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-large-diverse-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_False_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + +register( + id='antmaze-eval-large-play-v0', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.HARDEST_MAZE_EVAL_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_large_eval_noisy_multistart_True_multigoal_True_sparse.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + } +) + + +register( + id='antmaze-umaze-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) + +register( + id='antmaze-umaze-diverse-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=700, + kwargs={ + 'maze_map': maze_env.U_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) + +register( + id='antmaze-medium-play-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) + +register( + id='antmaze-medium-diverse-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.BIG_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) + +register( + id='antmaze-large-diverse-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) + +register( + id='antmaze-large-play-v2', + entry_point='d4rl.locomotion.ant:make_ant_maze_env', + max_episode_steps=1000, + kwargs={ + 'maze_map': maze_env.HARDEST_MAZE_TEST, + 'reward_type':'sparse', + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_v2/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse_fixed.hdf5', + 'non_zero_reset':False, + 'eval':True, + 'maze_size_scaling': 4.0, + 'ref_min_score': 0.0, + 'ref_max_score': 1.0, + 'v2_resets': True, + } +) diff --git a/d4rl/d4rl/locomotion/ant.py b/d4rl/d4rl/locomotion/ant.py new file mode 100644 index 0000000..8b1f292 --- /dev/null +++ b/d4rl/d4rl/locomotion/ant.py @@ -0,0 +1,213 @@ +# Copyright 2018 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Wrapper for creating the ant environment.""" + +import math +import numpy as np +import mujoco_py +import os + +from gym import utils +from gym.envs.mujoco import mujoco_env +from d4rl.locomotion import mujoco_goal_env + +from d4rl.locomotion import goal_reaching_env +from d4rl.locomotion import maze_env +from d4rl import offline_env +from d4rl.locomotion import wrappers + +GYM_ASSETS_DIR = os.path.join( + os.path.dirname(mujoco_goal_env.__file__), + 'assets') + +class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): + """Basic ant locomotion environment.""" + FILE = os.path.join(GYM_ASSETS_DIR, 'ant.xml') + + def __init__(self, file_path=None, expose_all_qpos=False, + expose_body_coms=None, expose_body_comvels=None, non_zero_reset=False): + if file_path is None: + file_path = self.FILE + + self._expose_all_qpos = expose_all_qpos + self._expose_body_coms = expose_body_coms + self._expose_body_comvels = expose_body_comvels + self._body_com_indices = {} + self._body_comvel_indices = {} + + self._non_zero_reset = non_zero_reset + + mujoco_env.MujocoEnv.__init__(self, file_path, 5) + utils.EzPickle.__init__(self) + + @property + def physics(self): + # Check mujoco version is greater than version 1.50 to call correct physics + # model containing PyMjData object for getting and setting position/velocity. + # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. + if mujoco_py.get_version() >= '1.50': + return self.sim + else: + return self.model + + def _step(self, a): + return self.step(a) + + def step(self, a): + xposbefore = self.get_body_com("torso")[0] + self.do_simulation(a, self.frame_skip) + xposafter = self.get_body_com("torso")[0] + forward_reward = (xposafter - xposbefore) / self.dt + ctrl_cost = .5 * np.square(a).sum() + contact_cost = 0.5 * 1e-3 * np.sum( + np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) + survive_reward = 1.0 + reward = forward_reward - ctrl_cost - contact_cost + survive_reward + state = self.state_vector() + notdone = np.isfinite(state).all() \ + and state[2] >= 0.2 and state[2] <= 1.0 + done = not notdone + ob = self._get_obs() + return ob, reward, done, dict( + reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + reward_survive=survive_reward) + + def _get_obs(self): + # No cfrc observation. + if self._expose_all_qpos: + obs = np.concatenate([ + self.physics.data.qpos.flat[:15], # Ensures only ant obs. 
+ self.physics.data.qvel.flat[:14], + ]) + else: + obs = np.concatenate([ + self.physics.data.qpos.flat[2:15], + self.physics.data.qvel.flat[:14], + ]) + + if self._expose_body_coms is not None: + for name in self._expose_body_coms: + com = self.get_body_com(name) + if name not in self._body_com_indices: + indices = range(len(obs), len(obs) + len(com)) + self._body_com_indices[name] = indices + obs = np.concatenate([obs, com]) + + if self._expose_body_comvels is not None: + for name in self._expose_body_comvels: + comvel = self.get_body_comvel(name) + if name not in self._body_comvel_indices: + indices = range(len(obs), len(obs) + len(comvel)) + self._body_comvel_indices[name] = indices + obs = np.concatenate([obs, comvel]) + return obs + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform( + size=self.model.nq, low=-.1, high=.1) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 + + if self._non_zero_reset: + """Now the reset is supposed to be to a non-zero location""" + reset_location = self._get_reset_location() + qpos[:2] = reset_location + + # Set everything other than ant to original position and 0 velocity. + qpos[15:] = self.init_qpos[15:] + qvel[14:] = 0. + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.distance = self.model.stat.extent * 0.5 + + def get_xy(self): + return self.physics.data.qpos[:2] + + def set_xy(self, xy): + qpos = np.copy(self.physics.data.qpos) + qpos[0] = xy[0] + qpos[1] = xy[1] + qvel = self.physics.data.qvel + self.set_state(qpos, qvel) + + +class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv): + """Ant locomotion rewarded for goal-reaching.""" + BASE_ENV = AntEnv + + def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler, + file_path=None, + expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type='dense', **kwargs): + goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type) + AntEnv.__init__(self, + file_path=file_path, + expose_all_qpos=expose_all_qpos, + expose_body_coms=None, + expose_body_comvels=None, + non_zero_reset=non_zero_reset) + +class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv): + """Ant navigating a maze.""" + LOCOMOTION_ENV = GoalReachingAntEnv + + def __init__(self, goal_sampler=None, expose_all_qpos=True, + reward_type='dense', v2_resets=False, + *args, **kwargs): + if goal_sampler is None: + goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand) + maze_env.MazeEnv.__init__( + self, *args, manual_collision=False, + goal_sampler=goal_sampler, + expose_all_qpos=expose_all_qpos, + reward_type=reward_type, + **kwargs) + offline_env.OfflineEnv.__init__(self, **kwargs) + + ## We set the target foal here for evaluation + self.set_target() + self.v2_resets = v2_resets + + def reset(self): + if self.v2_resets: + """ + The target goal for evaluation in antmazes is randomized. + antmazes-v0 and -v1 resulted in really high-variance evaluations + because the target goal was set once at the seed level. This led to + each run running evaluations with one particular goal. To accurately + cover each goal, this requires about 50-100 seeds, which might be + computationally infeasible. 
As an alternate fix, to reduce variance + in result reporting, we are creating the v2 environments + which use the same offline dataset as v0 environments, with the distinction + that the randomization of goals during evaluation is performed at the level of + each rollout. Thus running a few seeds, but performing the final evaluation + over 100-200 episodes will give a valid estimate of an algorithm's performance. + """ + self.set_target() + return super().reset() + + def set_target(self, target_location=None): + return self.set_target_goal(target_location) + + def seed(self, seed=0): + mujoco_env.MujocoEnv.seed(self, seed) + +def make_ant_maze_env(**kwargs): + env = AntMazeEnv(**kwargs) + return wrappers.NormalizedBoxEnv(env) + diff --git a/d4rl/d4rl/locomotion/assets/ant.xml b/d4rl/d4rl/locomotion/assets/ant.xml new file mode 100644 index 0000000..d39f7c3 --- /dev/null +++ b/d4rl/d4rl/locomotion/assets/ant.xml @@ -0,0 +1,81 @@ + + + diff --git a/d4rl/d4rl/locomotion/assets/point.xml b/d4rl/d4rl/locomotion/assets/point.xml new file mode 100644 index 0000000..d41ade0 --- /dev/null +++ b/d4rl/d4rl/locomotion/assets/point.xml @@ -0,0 +1,30 @@ + + + diff --git a/d4rl/d4rl/locomotion/common.py b/d4rl/d4rl/locomotion/common.py new file mode 100644 index 0000000..56c415c --- /dev/null +++ b/d4rl/d4rl/locomotion/common.py @@ -0,0 +1,21 @@ + + +def run_policy_on_env(policy_fn, env, truncate_episode_at=None, + first_obs=None): + if first_obs is None: + obs = env.reset() + else: + obs = first_obs + + trajectory = [] + step_num = 0 + while True: + act = policy_fn(obs) + next_obs, rew, done, _ = env.step(act) + trajectory.append((obs, act, rew, done)) + obs = next_obs + step_num += 1 + if (done or + (truncate_episode_at is not None and step_num >= truncate_episode_at)): + break + return trajectory diff --git a/d4rl/d4rl/locomotion/generate_dataset.py b/d4rl/d4rl/locomotion/generate_dataset.py new file mode 100644 index 0000000..3101f82 --- /dev/null +++ b/d4rl/d4rl/locomotion/generate_dataset.py @@ -0,0 +1,168 @@ +import numpy as np +import pickle +import gzip +import h5py +import argparse +from d4rl.locomotion import maze_env, ant, swimmer +from d4rl.locomotion.wrappers import NormalizedBoxEnv +from rlkit.torch.pytorch_util import set_gpu_mode +import torch +import skvideo.io +from PIL import Image +import os + + +def reset_data(): + return {'observations': [], + 'actions': [], + 'terminals': [], + 'rewards': [], + 'infos/goal': [], + 'infos/qpos': [], + 'infos/qvel': [], + } + +def append_data(data, s, a, r, tgt, done, env_data): + data['observations'].append(s) + data['actions'].append(a) + data['rewards'].append(r) + data['terminals'].append(done) + data['infos/goal'].append(tgt) + data['infos/qpos'].append(env_data.qpos.ravel().copy()) + data['infos/qvel'].append(env_data.qvel.ravel().copy()) + +def npify(data): + for k in data: + if k == 'terminals': + dtype = np.bool_ + else: + dtype = np.float32 + + data[k] = np.array(data[k], dtype=dtype) + +def load_policy(policy_file): + data = torch.load(policy_file) + policy = data['exploration/policy'] + env = data['evaluation/env'] + print("Policy loaded") + if True: + set_gpu_mode(True) + policy.cuda() + return policy, env + +def save_video(save_dir, file_name, frames, episode_id=0): + filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id)) + if not os.path.exists(filename): + os.makedirs(filename) + num_frames = frames.shape[0] + for i in range(num_frames): + img = Image.fromarray(np.flipud(frames[i]), 'RGB') + 
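+        # np.flipud above compensates for MuJoCo's renderer returning frames
+        # upside-down, so the saved PNG has the natural orientation.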
img.save(os.path.join(filename, 'frame_{}.png'.format(i))) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--noisy', action='store_true', help='Noisy actions') + parser.add_argument('--maze', type=str, default='u-maze', help='Maze type. small or default') + parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') + parser.add_argument('--env', type=str, default='Ant', help='Environment type') + parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name') + parser.add_argument('--max_episode_steps', default=1000, type=int) + parser.add_argument('--video', action='store_true') + parser.add_argument('--multi_start', action='store_true') + parser.add_argument('--multigoal', action='store_true') + args = parser.parse_args() + + if args.maze == 'u-maze': + maze = maze_env.U_MAZE + elif args.maze == 'big-maze': + maze = maze_env.BIG_MAZE + elif args.maze == 'hardest-maze': + maze = maze_env.HARDEST_MAZE + else: + raise NotImplementedError + + if args.env == 'Ant': + env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start)) + elif args.env == 'Swimmer': + env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(mmaze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start)) + + env.set_target_goal() + s = env.reset() + print (s.shape) + act = env.action_space.sample() + done = False + + # Load the policy + policy, train_env = load_policy(args.policy_file) + + # Define goal reaching policy fn + def _goal_reaching_policy_fn(obs, goal): + goal_x, goal_y = goal + obs_new = obs[2:-2] + goal_tuple = np.array([goal_x, goal_y]) + + # normalize the norm of the relative goals to in-distribution values + goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0 + + new_obs = np.concatenate([obs_new, goal_tuple], -1) + return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1]) + + data = reset_data() + + # create waypoint generating policy integrated with high level controller + data_collection_policy = env.create_navigation_policy( + _goal_reaching_policy_fn, + ) + + if args.video: + frames = [] + + ts = 0 + num_episodes = 0 + for _ in range(args.num_samples): + act, waypoint_goal = data_collection_policy(s) + + if args.noisy: + act = act + np.random.randn(*act.shape)*0.2 + act = np.clip(act, -1.0, 1.0) + + ns, r, done, info = env.step(act) + if ts >= args.max_episode_steps: + done = True + + append_data(data, s[:-2], act, r, env.target_goal, done, env.physics.data) + + if len(data['observations']) % 10000 == 0: + print(len(data['observations'])) + + ts += 1 + + if done: + done = False + ts = 0 + s = env.reset() + env.set_target_goal() + if args.video: + frames = np.array(frames) + save_video('./videos/', args.env + '_navigation', frames, num_episodes) + + num_episodes += 1 + frames = [] + else: + s = ns + + if args.video: + curr_frame = env.physics.render(width=500, height=500, depth=False) + frames.append(curr_frame) + + if args.noisy: + fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal)) + else: + fname = args.env + 'maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal)) + dataset = h5py.File(fname, 'w') + npify(data) + for k in data: + dataset.create_dataset(k, data=data[k], compression='gzip') + +if __name__ == '__main__': + main() diff --git a/d4rl/d4rl/locomotion/goal_reaching_env.py 
b/d4rl/d4rl/locomotion/goal_reaching_env.py new file mode 100644 index 0000000..eef1569 --- /dev/null +++ b/d4rl/d4rl/locomotion/goal_reaching_env.py @@ -0,0 +1,58 @@ +import numpy as np + + +def disk_goal_sampler(np_random, goal_region_radius=10.): + th = 2 * np.pi * np_random.uniform() + radius = goal_region_radius * np_random.uniform() + return radius * np.array([np.cos(th), np.sin(th)]) + +def constant_goal_sampler(np_random, location=10.0 * np.ones([2])): + return location + +class GoalReachingEnv(object): + """General goal-reaching environment.""" + BASE_ENV = None # Must be specified by child class. + + def __init__(self, goal_sampler, eval=False, reward_type='dense'): + self._goal_sampler = goal_sampler + self._goal = np.ones([2]) + self.target_goal = self._goal + + # This flag is used to make sure that when using this environment + # for evaluation, that is no goals are appended to the state + self.eval = eval + + # This is the reward type fed as input to the goal confitioned policy + self.reward_type = reward_type + + def _get_obs(self): + base_obs = self.BASE_ENV._get_obs(self) + goal_direction = self._goal - self.get_xy() + if not self.eval: + obs = np.concatenate([base_obs, goal_direction]) + return obs + else: + return base_obs + + def step(self, a): + self.BASE_ENV.step(self, a) + if self.reward_type == 'dense': + reward = -np.linalg.norm(self.target_goal - self.get_xy()) + elif self.reward_type == 'sparse': + reward = 1.0 if np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5 else 0.0 + + done = False + # Terminate episode when we reach a goal + if self.eval and np.linalg.norm(self.get_xy() - self.target_goal) <= 0.5: + done = True + + obs = self._get_obs() + return obs, reward, done, {} + + def reset_model(self): + if self.target_goal is not None or self.eval: + self._goal = self.target_goal + else: + self._goal = self._goal_sampler(self.np_random) + + return self.BASE_ENV.reset_model(self) \ No newline at end of file diff --git a/d4rl/d4rl/locomotion/maze_env.py b/d4rl/d4rl/locomotion/maze_env.py new file mode 100644 index 0000000..c6010f2 --- /dev/null +++ b/d4rl/d4rl/locomotion/maze_env.py @@ -0,0 +1,377 @@ +# Copyright 2018 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adapted from efficient-hrl maze_env.py.""" + +import os +import tempfile +import xml.etree.ElementTree as ET +import math +import numpy as np +import gym +from copy import deepcopy + +RESET = R = 'r' # Reset position. 
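+# Legend for the maze specifications below: 1 = wall block, 0 = free cell,
+# R = cell the robot is reset to, G = candidate goal cell. The *_TEST variants
+# keep a single G for evaluation; the generation/eval variants may mark several.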
+GOAL = G = 'g' + +# Maze specifications for dataset generation +U_MAZE = [[1, 1, 1, 1, 1], + [1, R, 0, 0, 1], + [1, 1, 1, 0, 1], + [1, G, 0, 0, 1], + [1, 1, 1, 1, 1]] + +BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 1, 1, 0, 0, 1], + [1, 0, 0, 1, 0, 0, G, 1], + [1, 1, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 1, 0, 0, 0, 1], + [1, G, 1, 0, 0, 1, 0, 1], + [1, 0, 0, 0, 1, G, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 0, 0, 1, G, 0, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [1, 0, 0, 0, 0, G, 0, 1, 0, 0, G, 1], + [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1], + [1, 0, G, 1, 0, 1, 0, 0, 0, 0, 0, 1], + [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1], + [1, 0, 0, 1, G, 0, G, 1, 0, G, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + +# Maze specifications with a single target goal +U_MAZE_TEST = [[1, 1, 1, 1, 1], + [1, R, 0, 0, 1], + [1, 1, 1, 0, 1], + [1, G, 0, 0, 1], + [1, 1, 1, 1, 1]] + +BIG_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 1, 1, 0, 0, 1], + [1, 0, 0, 1, 0, 0, 0, 1], + [1, 1, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 1, 0, 0, 1, 0, 1], + [1, 0, 0, 0, 1, 0, G, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +HARDEST_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1], + [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1], + [1, 0, 0, 1, 0, 0, 0, 1, 0, G, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + +# Maze specifications for evaluation +U_MAZE_EVAL = [[1, 1, 1, 1, 1], + [1, 0, 0, R, 1], + [1, 0, 1, 1, 1], + [1, 0, 0, G, 1], + [1, 1, 1, 1, 1]] + +BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 0, 0, 0, G, 1], + [1, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 0, 0, 0, 1, 0, 1], + [1, 1, 1, 0, 0, 1, 1, 1], + [1, G, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 1, 1, G, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +HARDEST_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 1, G, 0, 0, 1, 0, G, 0, 1], + [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1], + [1, 0, 0, 1, 0, 1, G, 0, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1], + [1, G, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1], + [1, 0, 0, 0, G, 1, G, 0, 0, 0, G, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + +U_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1], + [1, 0, 0, R, 1], + [1, 0, 1, 1, 1], + [1, 0, 0, G, 1], + [1, 1, 1, 1, 1]] + +BIG_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 0, 0, 0, G, 1], + [1, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 0, 0, 0, 1, 0, 1], + [1, 1, 1, 0, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 1, 1, 0, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +HARDEST_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 1, 0, 0, 0, 1, 0, G, 0, 1], + [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1], + [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + + +class MazeEnv(gym.Env): + LOCOMOTION_ENV = None # Must be specified by child class. 
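+  # MazeEnv is written as a mixin: concrete envs elsewhere in this patch
+  # (AntMazeEnv, PointMazeEnv, SwimmerMazeEnv) set LOCOMOTION_ENV to their
+  # goal-reaching locomotion class. __init__ parses that class's MuJoCo XML,
+  # adds one box geom per wall cell of maze_map (offset so the robot starts at
+  # the origin), writes the result to a temporary file, and then initializes
+  # LOCOMOTION_ENV from the generated XML.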
+ + def __init__( + self, + maze_map, + maze_size_scaling, + maze_height=0.5, + manual_collision=False, + non_zero_reset=False, + reward_type='dense', + *args, + **kwargs): + if self.LOCOMOTION_ENV is None: + raise ValueError('LOCOMOTION_ENV is unspecified.') + + xml_path = self.LOCOMOTION_ENV.FILE + tree = ET.parse(xml_path) + worldbody = tree.find(".//worldbody") + + self._maze_map = maze_map + + self._maze_height = maze_height + self._maze_size_scaling = maze_size_scaling + self._manual_collision = manual_collision + + self._maze_map = maze_map + + # Obtain a numpy array form for a maze map in case we want to reset + # to multiple starting states + temp_maze_map = deepcopy(self._maze_map) + for i in range(len(maze_map)): + for j in range(len(maze_map[0])): + if temp_maze_map[i][j] in [RESET,]: + temp_maze_map[i][j] = 0 + elif temp_maze_map[i][j] in [GOAL,]: + temp_maze_map[i][j] = 1 + + self._np_maze_map = np.array(temp_maze_map) + + torso_x, torso_y = self._find_robot() + self._init_torso_x = torso_x + self._init_torso_y = torso_y + + for i in range(len(self._maze_map)): + for j in range(len(self._maze_map[0])): + struct = self._maze_map[i][j] + if struct == 1: # Unmovable block. + # Offset all coordinates so that robot starts at the origin. + ET.SubElement( + worldbody, "geom", + name="block_%d_%d" % (i, j), + pos="%f %f %f" % (j * self._maze_size_scaling - torso_x, + i * self._maze_size_scaling - torso_y, + self._maze_height / 2 * self._maze_size_scaling), + size="%f %f %f" % (0.5 * self._maze_size_scaling, + 0.5 * self._maze_size_scaling, + self._maze_height / 2 * self._maze_size_scaling), + type="box", + material="", + contype="1", + conaffinity="1", + rgba="0.7 0.5 0.3 1.0", + ) + + torso = tree.find(".//body[@name='torso']") + geoms = torso.findall(".//geom") + + _, file_path = tempfile.mkstemp(text=True, suffix='.xml') + tree.write(file_path) + + self.LOCOMOTION_ENV.__init__(self, *args, file_path=file_path, non_zero_reset=non_zero_reset, reward_type=reward_type, **kwargs) + + self.target_goal = None + + def _xy_to_rowcol(self, xy): + size_scaling = self._maze_size_scaling + xy = (max(xy[0], 1e-4), max(xy[1], 1e-4)) + return (int(1 + (xy[1]) / size_scaling), + int(1 + (xy[0]) / size_scaling)) + + def _get_reset_location(self,): + prob = (1.0 - self._np_maze_map) / np.sum(1.0 - self._np_maze_map) + prob_row = np.sum(prob, 1) + row_sample = np.random.choice(np.arange(self._np_maze_map.shape[0]), p=prob_row) + col_sample = np.random.choice(np.arange(self._np_maze_map.shape[1]), p=prob[row_sample] * 1.0 / prob_row[row_sample]) + reset_location = self._rowcol_to_xy((row_sample, col_sample)) + + # Add some random noise + random_x = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling + random_y = np.random.uniform(low=0, high=0.5) * 0.5 * self._maze_size_scaling + + return (max(reset_location[0] + random_x, 0), max(reset_location[1] + random_y, 0)) + + def _rowcol_to_xy(self, rowcol, add_random_noise=False): + row, col = rowcol + x = col * self._maze_size_scaling - self._init_torso_x + y = row * self._maze_size_scaling - self._init_torso_y + if add_random_noise: + x = x + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25) + y = y + np.random.uniform(low=0, high=self._maze_size_scaling * 0.25) + return (x, y) + + def goal_sampler(self, np_random, only_free_cells=True, interpolate=True): + valid_cells = [] + goal_cells = [] + + for i in range(len(self._maze_map)): + for j in range(len(self._maze_map[0])): + if self._maze_map[i][j] in [0, RESET, GOAL] or not 
only_free_cells: + valid_cells.append((i, j)) + if self._maze_map[i][j] == GOAL: + goal_cells.append((i, j)) + + # If there is a 'goal' designated, use that. Otherwise, any valid cell can + # be a goal. + sample_choices = goal_cells if goal_cells else valid_cells + cell = sample_choices[np_random.choice(len(sample_choices))] + xy = self._rowcol_to_xy(cell, add_random_noise=True) + + random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling + random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling + + xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0)) + + return xy + + def set_target_goal(self, goal_input=None): + if goal_input is None: + self.target_goal = self.goal_sampler(np.random) + else: + self.target_goal = goal_input + + print ('Target Goal: ', self.target_goal) + ## Make sure that the goal used in self._goal is also reset: + self._goal = self.target_goal + + def _find_robot(self): + structure = self._maze_map + size_scaling = self._maze_size_scaling + for i in range(len(structure)): + for j in range(len(structure[0])): + if structure[i][j] == RESET: + return j * size_scaling, i * size_scaling + raise ValueError('No robot in maze specification.') + + def _is_in_collision(self, pos): + x, y = pos + structure = self._maze_map + size_scaling = self._maze_size_scaling + for i in range(len(structure)): + for j in range(len(structure[0])): + if structure[i][j] == 1: + minx = j * size_scaling - size_scaling * 0.5 - self._init_torso_x + maxx = j * size_scaling + size_scaling * 0.5 - self._init_torso_x + miny = i * size_scaling - size_scaling * 0.5 - self._init_torso_y + maxy = i * size_scaling + size_scaling * 0.5 - self._init_torso_y + if minx <= x <= maxx and miny <= y <= maxy: + return True + return False + + def step(self, action): + if self._manual_collision: + old_pos = self.get_xy() + inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action) + new_pos = self.get_xy() + if self._is_in_collision(new_pos): + self.set_xy(old_pos) + else: + inner_next_obs, inner_reward, done, info = self.LOCOMOTION_ENV.step(self, action) + next_obs = self._get_obs() + return next_obs, inner_reward, done, info + + def _get_best_next_rowcol(self, current_rowcol, target_rowcol): + """Runs BFS to find shortest path to target and returns best next rowcol. + Add obstacle avoidance""" + current_rowcol = tuple(current_rowcol) + target_rowcol = tuple(target_rowcol) + if target_rowcol == current_rowcol: + return target_rowcol + + visited = {} + to_visit = [target_rowcol] + while to_visit: + next_visit = [] + for rowcol in to_visit: + visited[rowcol] = True + row, col = rowcol + left = (row, col - 1) + right = (row, col + 1) + down = (row + 1, col) + up = (row - 1, col) + for next_rowcol in [left, right, down, up]: + if next_rowcol == current_rowcol: # Found a shortest path. 
+ return rowcol + next_row, next_col = next_rowcol + if next_row < 0 or next_row >= len(self._maze_map): + continue + if next_col < 0 or next_col >= len(self._maze_map[0]): + continue + if self._maze_map[next_row][next_col] not in [0, RESET, GOAL]: + continue + if next_rowcol in visited: + continue + next_visit.append(next_rowcol) + to_visit = next_visit + + raise ValueError('No path found to target.') + + def create_navigation_policy(self, + goal_reaching_policy_fn, + obs_to_robot=lambda obs: obs[:2], + obs_to_target=lambda obs: obs[-2:], + relative=False): + """Creates a navigation policy by guiding a sub-policy to waypoints.""" + + def policy_fn(obs): + # import ipdb; ipdb.set_trace() + robot_x, robot_y = obs_to_robot(obs) + robot_row, robot_col = self._xy_to_rowcol([robot_x, robot_y]) + target_x, target_y = self.target_goal + if relative: + target_x += robot_x # Target is given in relative coordinates. + target_y += robot_y + target_row, target_col = self._xy_to_rowcol([target_x, target_y]) + print ('Target: ', target_row, target_col, target_x, target_y) + print ('Robot: ', robot_row, robot_col, robot_x, robot_y) + + waypoint_row, waypoint_col = self._get_best_next_rowcol( + [robot_row, robot_col], [target_row, target_col]) + + if waypoint_row == target_row and waypoint_col == target_col: + waypoint_x = target_x + waypoint_y = target_y + else: + waypoint_x, waypoint_y = self._rowcol_to_xy([waypoint_row, waypoint_col], add_random_noise=True) + + goal_x = waypoint_x - robot_x + goal_y = waypoint_y - robot_y + + print ('Waypoint: ', waypoint_row, waypoint_col, waypoint_x, waypoint_y) + + return goal_reaching_policy_fn(obs, (goal_x, goal_y)) + + return policy_fn diff --git a/d4rl/d4rl/locomotion/mujoco_goal_env.py b/d4rl/d4rl/locomotion/mujoco_goal_env.py new file mode 100644 index 0000000..714facb --- /dev/null +++ b/d4rl/d4rl/locomotion/mujoco_goal_env.py @@ -0,0 +1,191 @@ +from collections import OrderedDict +import os + + +from gym import error, spaces +from gym.utils import seeding +import numpy as np +from os import path +import gym + +try: + import mujoco_py +except ImportError as e: + raise error.DependencyNotInstalled("{}. 
(HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e)) + +DEFAULT_SIZE = 500 + +def convert_observation_to_space(observation): + if isinstance(observation, dict): + space = spaces.Dict(OrderedDict([ + (key, convert_observation_to_space(value)) + for key, value in observation.items() + ])) + elif isinstance(observation, np.ndarray): + low = np.full(observation.shape, -float('inf'), dtype=np.float32) + high = np.full(observation.shape, float('inf'), dtype=np.float32) + space = spaces.Box(low, high, dtype=observation.dtype) + else: + raise NotImplementedError(type(observation), observation) + + return space + +class MujocoGoalEnv(gym.Env): + """SuperClass for all MuJoCo goal reaching environments""" + + def __init__(self, model_path, frame_skip): + if model_path.startswith("/"): + fullpath = model_path + else: + fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) + if not path.exists(fullpath): + raise IOError("File %s does not exist" % fullpath) + self.frame_skip = frame_skip + self.model = mujoco_py.load_model_from_path(fullpath) + self.sim = mujoco_py.MjSim(self.model) + self.data = self.sim.data + self.viewer = None + self._viewers = {} + + self.metadata = { + 'render.modes': ['human', 'rgb_array', 'depth_array'], + 'video.frames_per_second': int(np.round(1.0 / self.dt)) + } + + self.init_qpos = self.sim.data.qpos.ravel().copy() + self.init_qvel = self.sim.data.qvel.ravel().copy() + + self._set_action_space() + + action = self.action_space.sample() + # import ipdb; ipdb.set_trace() + observation, _reward, done, _info = self.step(action) + assert not done + + self._set_observation_space(observation['observation']) + + self.seed() + + def _set_action_space(self): + bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) + low, high = bounds.T + self.action_space = spaces.Box(low=low, high=high, dtype=np.float32) + return self.action_space + + # def _set_observation_space(self, observation): + # self.observation_space = convert_observation_to_space(observation) + # return self.observation_space + + def _set_observation_space(self, observation): + temp_observation_space = convert_observation_to_space(observation) + self.observation_space = spaces.Dict(dict( + observation=temp_observation_space, + desired_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32), + achieved_goal=spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32), + )) + return self.observation_space + + def seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + return [seed] + + # methods to override: + # ---------------------------- + + def reset_model(self): + """ + Reset the robot degrees of freedom (qpos and qvel). + Implement this in each subclass. + """ + raise NotImplementedError + + def viewer_setup(self): + """ + This method is called when the viewer is initialized. + Optionally implement this method, if you need to tinker with camera position + and so forth. 
+ """ + pass + + def reset(self): + self.sim.reset() + ob = self.reset_model() + return ob + + def set_state(self, qpos, qvel): + assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) + old_state = self.sim.get_state() + new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, + old_state.act, old_state.udd_state) + self.sim.set_state(new_state) + self.sim.forward() + + @property + def dt(self): + return self.model.opt.timestep * self.frame_skip + + def do_simulation(self, ctrl, n_frames): + self.sim.data.ctrl[:] = ctrl + for _ in range(n_frames): + self.sim.step() + + def render(self, + mode='human', + width=DEFAULT_SIZE, + height=DEFAULT_SIZE, + camera_id=None, + camera_name=None): + if mode == 'rgb_array': + if camera_id is not None and camera_name is not None: + raise ValueError("Both `camera_id` and `camera_name` cannot be" + " specified at the same time.") + + no_camera_specified = camera_name is None and camera_id is None + if no_camera_specified: + camera_name = 'track' + + if camera_id is None and camera_name in self.model._camera_name2id: + camera_id = self.model.camera_name2id(camera_name) + + self._get_viewer(mode).render(width, height, camera_id=camera_id) + # window size used for old mujoco-py: + data = self._get_viewer(mode).read_pixels(width, height, depth=False) + # original image is upside-down, so flip it + return data[::-1, :, :] + elif mode == 'depth_array': + self._get_viewer(mode).render(width, height) + # window size used for old mujoco-py: + # Extract depth part of the read_pixels() tuple + data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1] + # original image is upside-down, so flip it + return data[::-1, :] + elif mode == 'human': + self._get_viewer(mode).render() + + def close(self): + if self.viewer is not None: + # self.viewer.finish() + self.viewer = None + self._viewers = {} + + def _get_viewer(self, mode): + self.viewer = self._viewers.get(mode) + if self.viewer is None: + if mode == 'human': + self.viewer = mujoco_py.MjViewer(self.sim) + elif mode == 'rgb_array' or mode == 'depth_array': + self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1) + + self.viewer_setup() + self._viewers[mode] = self.viewer + return self.viewer + + def get_body_com(self, body_name): + return self.data.get_body_xpos(body_name) + + def state_vector(self): + return np.concatenate([ + self.sim.data.qpos.flat, + self.sim.data.qvel.flat + ]) + diff --git a/d4rl/d4rl/locomotion/point.py b/d4rl/d4rl/locomotion/point.py new file mode 100644 index 0000000..fbff03b --- /dev/null +++ b/d4rl/d4rl/locomotion/point.py @@ -0,0 +1,196 @@ +# Copyright 2018 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Wrapper for creating the point environment.""" + +import math +import numpy as np +import mujoco_py +import os + +from gym import utils +from gym.envs.mujoco import mujoco_env +from d4rl.locomotion import mujoco_goal_env + +from d4rl.locomotion import goal_reaching_env +from d4rl.locomotion import maze_env + +MY_ASSETS_DIR = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'assets') + + +class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): + FILE = os.path.join(MY_ASSETS_DIR, 'point.xml') + + def __init__(self, file_path=None, expose_all_qpos=False): + if file_path is None: + file_path = self.FILE + + self._expose_all_qpos = expose_all_qpos + + mujoco_env.MujocoEnv.__init__(self, file_path, 1) + # mujoco_goal_env.MujocoGoalEnv.__init__(self, file_path, 1) + utils.EzPickle.__init__(self) + + @property + def physics(self): + # Check mujoco version is greater than version 1.50 to call correct physics + # model containing PyMjData object for getting and setting position/velocity. + # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. + if mujoco_py.get_version() >= '1.50': + return self.sim + else: + return self.model + + def _step(self, a): + return self.step(a) + + def step(self, action): + action[0] = 0.2 * action[0] + qpos = np.copy(self.physics.data.qpos) + qpos[2] += action[1] + ori = qpos[2] + # Compute increment in each direction. + dx = math.cos(ori) * action[0] + dy = math.sin(ori) * action[0] + # Ensure that the robot is within reasonable range. + qpos[0] = np.clip(qpos[0] + dx, -100, 100) + qpos[1] = np.clip(qpos[1] + dy, -100, 100) + qvel = self.physics.data.qvel + self.set_state(qpos, qvel) + for _ in range(0, self.frame_skip): + self.physics.step() + next_obs = self._get_obs() + reward = 0 + done = False + info = {} + return next_obs, reward, done, info + + def _get_obs(self): + if self._expose_all_qpos: + return np.concatenate([ + self.physics.data.qpos.flat[:3], # Only point-relevant coords. + self.physics.data.qvel.flat[:3]]) + return np.concatenate([ + self.physics.data.qpos.flat[2:3], + self.physics.data.qvel.flat[:3]]) + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform( + size=self.physics.model.nq, low=-.1, high=.1) + qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1 + + # Set everything other than point to original position and 0 velocity. + qpos[3:] = self.init_qpos[3:] + qvel[3:] = 0. 
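+    # qpos[0:3] hold the point's x, y and heading (the same entries step()
+    # integrates); any remaining entries belong to extra bodies the compiled
+    # maze model may add, and are pinned back to their initial state here.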
+ self.set_state(qpos, qvel) + return self._get_obs() + + def get_xy(self): + return self.physics.data.qpos[:2] + + def set_xy(self, xy): + qpos = np.copy(self.physics.data.qpos) + qpos[0] = xy[0] + qpos[1] = xy[1] + qvel = self.physics.data.qvel + self.set_state(qpos, qvel) + + +class GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv): + """Point locomotion rewarded for goal-reaching.""" + BASE_ENV = PointEnv + + def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler, + file_path=None, + expose_all_qpos=False): + goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler) + PointEnv.__init__(self, + file_path=file_path, + expose_all_qpos=expose_all_qpos) + +class GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, PointEnv): + """Ant locomotion for goal reaching in a disctionary compatible format.""" + BASE_ENV = PointEnv + + def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler, + file_path=None, + expose_all_qpos=False): + goal_reaching_env.GoalReachingDictEnv.__init__(self, goal_sampler) + PointEnv.__init__(self, + file_path=file_path, + expose_all_qpos=expose_all_qpos) + +class PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv): + """Point navigating a maze.""" + LOCOMOTION_ENV = GoalReachingPointEnv + + def __init__(self, goal_sampler=None, expose_all_qpos=True, + *args, **kwargs): + if goal_sampler is None: + goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand) + maze_env.MazeEnv.__init__( + self, *args, manual_collision=True, + goal_sampler=goal_sampler, + expose_all_qpos=expose_all_qpos, + **kwargs) + + +def create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:], + obs_to_ori=lambda obs: obs[0]): + """A hard-coded policy for reaching a goal position.""" + + def policy_fn(obs): + goal_x, goal_y = obs_to_goal(obs) + goal_dist = np.linalg.norm([goal_x, goal_y]) + goal_ori = np.arctan2(goal_y, goal_x) + ori = obs_to_ori(obs) + ori_diff = (goal_ori - ori) % (2 * np.pi) + + radius = goal_dist / 2. / max(0.1, np.abs(np.sin(ori_diff))) + rotation_left = (2 * ori_diff) % np.pi + circumference_left = max(goal_dist, radius * rotation_left) + + speed = min(circumference_left * 5., 1.0) + velocity = speed + if ori_diff > np.pi / 2 and ori_diff < 3 * np.pi / 2: + velocity *= -1 + + time_left = min(circumference_left / (speed * 0.2), 10.) + signed_ori_diff = ori_diff + if signed_ori_diff >= 3 * np.pi / 2: + signed_ori_diff = 2 * np.pi - signed_ori_diff + elif signed_ori_diff > np.pi / 2 and signed_ori_diff < 3 * np.pi / 2: + signed_ori_diff = signed_ori_diff - np.pi + + angular_velocity = signed_ori_diff / time_left + angular_velocity = np.clip(angular_velocity, -1., 1.) 
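+    # Summary of the hand-coded controller above: the point tracks a circular
+    # arc to the goal. `velocity` grows with the remaining arc length (capped
+    # at 1 and negated when the goal lies behind the agent), while
+    # `angular_velocity` spreads the signed heading error over the estimated
+    # time left, clipped to the actuator range.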
+ + return np.array([velocity, angular_velocity]) + + return policy_fn + + +def create_maze_navigation_policy(maze_env): + """Creates a hard-coded policy to navigate a maze.""" + ori_index = 2 if maze_env._expose_all_qpos else 0 + obs_to_ori = lambda obs: obs[ori_index] + + goal_reaching_policy = create_goal_reaching_policy(obs_to_ori=obs_to_ori) + goal_reaching_policy_fn = lambda obs, goal: goal_reaching_policy( + np.concatenate([obs, goal])) + + return maze_env.create_navigation_policy(goal_reaching_policy_fn) diff --git a/d4rl/d4rl/locomotion/swimmer.py b/d4rl/d4rl/locomotion/swimmer.py new file mode 100644 index 0000000..bb8282a --- /dev/null +++ b/d4rl/d4rl/locomotion/swimmer.py @@ -0,0 +1,125 @@ +"""Wrapper for creating the swimmer environment.""" + +import math +import numpy as np +import mujoco_py +import os + +from gym import utils +from gym.envs.mujoco import mujoco_env +from d4rl.locomotion import mujoco_goal_env + +from d4rl.locomotion import goal_reaching_env +from d4rl.locomotion import maze_env +from d4rl import offline_env + +GYM_ASSETS_DIR = os.path.join( + os.path.dirname(mujoco_env.__file__), + 'assets') + + +class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): + """Basic swimmer locomotion environment.""" + FILE = os.path.join(GYM_ASSETS_DIR, 'swimmer.xml') + + def __init__(self, file_path=None, expose_all_qpos=False, non_zero_reset=False): + if file_path is None: + file_path = self.FILE + + self._expose_all_qpos = expose_all_qpos + + mujoco_env.MujocoEnv.__init__(self, file_path, 5) + utils.EzPickle.__init__(self) + + @property + def physics(self): + # Check mujoco version is greater than version 1.50 to call correct physics + # model containing PyMjData object for getting and setting position/velocity. + # Check https://github.com/openai/mujoco-py/issues/80 for updates to api. + if mujoco_py.get_version() >= '1.50': + return self.sim + else: + return self.model + + def _step(self, a): + return self.step(a) + + def step(self, a): + ctrl_cost_coeff = 0.0001 + xposbefore = self.sim.data.qpos[0] + self.do_simulation(a, self.frame_skip) + xposafter = self.sim.data.qpos[0] + reward_fwd = (xposafter - xposbefore) / self.dt + reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() + reward = reward_fwd + reward_ctrl + ob = self._get_obs() + return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) + + def _get_obs(self): + if self._expose_all_qpos: + obs = np.concatenate([ + self.physics.data.qpos.flat[:5], # Ensures only swimmer obs. + self.physics.data.qvel.flat[:5], + ]) + else: + obs = np.concatenate([ + self.physics.data.qpos.flat[2:5], + self.physics.data.qvel.flat[:5], + ]) + + return obs + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform( + size=self.model.nq, low=-.1, high=.1) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 + + # Set everything other than swimmer to original position and 0 velocity. + qpos[5:] = self.init_qpos[5:] + qvel[5:] = 0. 
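+    # Assumed layout of gym's swimmer.xml: qpos[:5] = root x, root y, root yaw
+    # plus the two joint angles (matching the flat[:5] slices in _get_obs);
+    # anything beyond index 5 belongs to extra bodies and is reset to its
+    # initial state with zero velocity here.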
+ self.set_state(qpos, qvel) + return self._get_obs() + + def get_xy(self): + return self.physics.data.qpos[:2] + + def set_xy(self, xy): + qpos = np.copy(self.physics.data.qpos) + qpos[0] = xy[0] + qpos[1] = xy[1] + qvel = self.physics.data.qvel + self.set_state(qpos, qvel) + + +class GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerEnv): + """Swimmer locomotion rewarded for goal-reaching.""" + BASE_ENV = SwimmerEnv + + def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler, + file_path=None, + expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type="dense", **kwargs): + goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type) + SwimmerEnv.__init__(self, + file_path=file_path, + expose_all_qpos=expose_all_qpos, + non_zero_reset=non_zero_reset) + +class SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv, offline_env.OfflineEnv): + """Swimmer navigating a maze.""" + LOCOMOTION_ENV = GoalReachingSwimmerEnv + + def __init__(self, goal_sampler=None, expose_all_qpos=True, + reward_type='dense', + *args, **kwargs): + if goal_sampler is None: + goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand) + maze_env.MazeEnv.__init__( + self, *args, manual_collision=False, + goal_sampler=goal_sampler, + expose_all_qpos=expose_all_qpos, + reward_type=reward_type, + **kwargs) + offline_env.OfflineEnv.__init__(self, **kwargs) + + def set_target(self, target_location=None): + return self.set_target_goal(target_location) diff --git a/d4rl/d4rl/locomotion/wrappers.py b/d4rl/d4rl/locomotion/wrappers.py new file mode 100644 index 0000000..45b371c --- /dev/null +++ b/d4rl/d4rl/locomotion/wrappers.py @@ -0,0 +1,168 @@ +import numpy as np +import itertools +from gym import Env +from gym.spaces import Box +from gym.spaces import Discrete + +from collections import deque + + +class ProxyEnv(Env): + def __init__(self, wrapped_env): + self._wrapped_env = wrapped_env + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + return self._wrapped_env.reset(**kwargs) + + def step(self, action): + return self._wrapped_env.step(action) + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self.wrapped_env.terminate() + + def __getattr__(self, attr): + if attr == '_wrapped_env': + raise AttributeError() + return getattr(self._wrapped_env, attr) + + def __getstate__(self): + """ + This is useful to override in case the wrapped env has some funky + __getstate__ that doesn't play well with overriding __getattr__. + + The main problematic case is/was gym's EzPickle serialization scheme. 
+ :return: + """ + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return '{}({})'.format(type(self).__name__, self.wrapped_env) + + +class HistoryEnv(ProxyEnv, Env): + def __init__(self, wrapped_env, history_len): + super().__init__(wrapped_env) + self.history_len = history_len + + high = np.inf * np.ones( + self.history_len * self.observation_space.low.size) + low = -high + self.observation_space = Box(low=low, + high=high, + ) + self.history = deque(maxlen=self.history_len) + + def step(self, action): + state, reward, done, info = super().step(action) + self.history.append(state) + flattened_history = self._get_history().flatten() + return flattened_history, reward, done, info + + def reset(self, **kwargs): + state = super().reset() + self.history = deque(maxlen=self.history_len) + self.history.append(state) + flattened_history = self._get_history().flatten() + return flattened_history + + def _get_history(self): + observations = list(self.history) + + obs_count = len(observations) + for _ in range(self.history_len - obs_count): + dummy = np.zeros(self._wrapped_env.observation_space.low.size) + observations.append(dummy) + return np.c_[observations] + + +class DiscretizeEnv(ProxyEnv, Env): + def __init__(self, wrapped_env, num_bins): + super().__init__(wrapped_env) + low = self.wrapped_env.action_space.low + high = self.wrapped_env.action_space.high + action_ranges = [ + np.linspace(low[i], high[i], num_bins) + for i in range(len(low)) + ] + self.idx_to_continuous_action = [ + np.array(x) for x in itertools.product(*action_ranges) + ] + self.action_space = Discrete(len(self.idx_to_continuous_action)) + + def step(self, action): + continuous_action = self.idx_to_continuous_action[action] + return super().step(continuous_action) + + +class NormalizedBoxEnv(ProxyEnv): + """ + Normalize action to in [-1, 1]. + + Optionally normalize observations and scale reward. + """ + + def __init__( + self, + env, + reward_scale=1., + obs_mean=None, + obs_std=None, + ): + ProxyEnv.__init__(self, env) + self._should_normalize = not (obs_mean is None and obs_std is None) + if self._should_normalize: + if obs_mean is None: + obs_mean = np.zeros_like(env.observation_space.low) + else: + obs_mean = np.array(obs_mean) + if obs_std is None: + obs_std = np.ones_like(env.observation_space.low) + else: + obs_std = np.array(obs_std) + self._reward_scale = reward_scale + self._obs_mean = obs_mean + self._obs_std = obs_std + ub = np.ones(self._wrapped_env.action_space.shape) + self.action_space = Box(-1 * ub, ub) + + def estimate_obs_stats(self, obs_batch, override_values=False): + if self._obs_mean is not None and not override_values: + raise Exception("Observation mean and std already set. To " + "override, set override_values to True.") + self._obs_mean = np.mean(obs_batch, axis=0) + self._obs_std = np.std(obs_batch, axis=0) + + def _apply_normalize_obs(self, obs): + return (obs - self._obs_mean) / (self._obs_std + 1e-8) + + def step(self, action): + lb = self._wrapped_env.action_space.low + ub = self._wrapped_env.action_space.high + scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) + scaled_action = np.clip(scaled_action, lb, ub) + + wrapped_step = self._wrapped_env.step(scaled_action) + next_obs, reward, done, info = wrapped_step + if self._should_normalize: + next_obs = self._apply_normalize_obs(next_obs) + return next_obs, reward * self._reward_scale, done, info + + def __str__(self): + return "Normalized: %s" % self._wrapped_env diff --git a/d4rl/d4rl/offline_env.py b/d4rl/d4rl/offline_env.py new file mode 100644 index 0000000..c438405 --- /dev/null +++ b/d4rl/d4rl/offline_env.py @@ -0,0 +1,154 @@ +import os +import urllib.request +import warnings + +import gym +from gym.utils import colorize +import h5py +from tqdm import tqdm + + +def set_dataset_path(path): + global DATASET_PATH + DATASET_PATH = path + os.makedirs(path, exist_ok=True) + + +set_dataset_path(os.environ.get('D4RL_DATASET_DIR', os.path.expanduser('~/.d4rl/datasets'))) + + +def get_keys(h5file): + keys = [] + + def visitor(name, item): + if isinstance(item, h5py.Dataset): + keys.append(name) + + h5file.visititems(visitor) + return keys + + +def filepath_from_url(dataset_url): + _, dataset_name = os.path.split(dataset_url) + dataset_filepath = os.path.join(DATASET_PATH, dataset_name) + return dataset_filepath + + +def download_dataset_from_url(dataset_url): + dataset_filepath = filepath_from_url(dataset_url) + if not os.path.exists(dataset_filepath): + print('Downloading dataset:', dataset_url, 'to', dataset_filepath) + urllib.request.urlretrieve(dataset_url, dataset_filepath) + if not os.path.exists(dataset_filepath): + raise IOError("Failed to download dataset from %s" % dataset_url) + return dataset_filepath + + +class OfflineEnv(gym.Env): + """ + Base class for offline RL envs. + + Args: + dataset_url: URL pointing to the dataset. + ref_max_score: Maximum score (for score normalization) + ref_min_score: Minimum score (for score normalization) + deprecated: If True, will display a warning that the environment is deprecated. + """ + + def __init__(self, dataset_url=None, ref_max_score=None, ref_min_score=None, + deprecated=False, deprecation_message=None, **kwargs): + super(OfflineEnv, self).__init__(**kwargs) + self.dataset_url = self._dataset_url = dataset_url + self.ref_max_score = ref_max_score + self.ref_min_score = ref_min_score + if deprecated: + if deprecation_message is None: + deprecation_message = "This environment is deprecated. Please use the most recent version of this environment." + # stacklevel=2 will bump the warning to the superclass. 
+ warnings.warn(colorize(deprecation_message, 'yellow'), stacklevel=2) + + + def get_normalized_score(self, score): + if (self.ref_max_score is None) or (self.ref_min_score is None): + raise ValueError("Reference score not provided for env") + return (score - self.ref_min_score) / (self.ref_max_score - self.ref_min_score) + + @property + def dataset_filepath(self): + return filepath_from_url(self.dataset_url) + + def get_dataset(self, h5path=None): + if h5path is None: + if self._dataset_url is None: + raise ValueError("Offline env not configured with a dataset URL.") + h5path = download_dataset_from_url(self.dataset_url) + + data_dict = {} + with h5py.File(h5path, 'r') as dataset_file: + for k in tqdm(get_keys(dataset_file), desc="load datafile"): + try: # first try loading as an array + data_dict[k] = dataset_file[k][:] + except ValueError as e: # try loading as a scalar + data_dict[k] = dataset_file[k][()] + + # Run a few quick sanity checks + for key in ['observations', 'actions', 'rewards', 'terminals']: + assert key in data_dict, 'Dataset is missing key %s' % key + N_samples = data_dict['observations'].shape[0] + if self.observation_space.shape is not None: + assert data_dict['observations'].shape[1:] == self.observation_space.shape, \ + 'Observation shape does not match env: %s vs %s' % ( + str(data_dict['observations'].shape[1:]), str(self.observation_space.shape)) + assert data_dict['actions'].shape[1:] == self.action_space.shape, \ + 'Action shape does not match env: %s vs %s' % ( + str(data_dict['actions'].shape[1:]), str(self.action_space.shape)) + if data_dict['rewards'].shape == (N_samples, 1): + data_dict['rewards'] = data_dict['rewards'][:, 0] + assert data_dict['rewards'].shape == (N_samples,), 'Reward has wrong shape: %s' % ( + str(data_dict['rewards'].shape)) + if data_dict['terminals'].shape == (N_samples, 1): + data_dict['terminals'] = data_dict['terminals'][:, 0] + assert data_dict['terminals'].shape == (N_samples,), 'Terminals has wrong shape: %s' % ( + str(data_dict['rewards'].shape)) + return data_dict + + def get_dataset_chunk(self, chunk_id, h5path=None): + """ + Returns a slice of the full dataset. + + Args: + chunk_id (int): An integer representing which slice of the dataset to return. + + Returns: + A dictionary containing observtions, actions, rewards, and terminals. + """ + if h5path is None: + if self._dataset_url is None: + raise ValueError("Offline env not configured with a dataset URL.") + h5path = download_dataset_from_url(self.dataset_url) + + dataset_file = h5py.File(h5path, 'r') + + if 'virtual' not in dataset_file.keys(): + raise ValueError('Dataset is not a chunked dataset') + available_chunks = [int(_chunk) for _chunk in list(dataset_file['virtual'].keys())] + if chunk_id not in available_chunks: + raise ValueError('Chunk id not found: %d. Available chunks: %s' % (chunk_id, str(available_chunks))) + + load_keys = ['observations', 'actions', 'rewards', 'terminals'] + data_dict = {k: dataset_file['virtual/%d/%s' % (chunk_id, k)][:] for k in load_keys} + dataset_file.close() + return data_dict + + +class OfflineEnvWrapper(gym.Wrapper, OfflineEnv): + """ + Wrapper class for offline RL envs. + """ + + def __init__(self, env, **kwargs): + gym.Wrapper.__init__(self, env) + OfflineEnv.__init__(self, **kwargs) + + def reset(self): + return self.env.reset() diff --git a/d4rl/d4rl/ope.py b/d4rl/d4rl/ope.py new file mode 100644 index 0000000..85658eb --- /dev/null +++ b/d4rl/d4rl/ope.py @@ -0,0 +1,132 @@ +""" +Metrics for off-policy evaluation. 
+""" +from d4rl import infos +import numpy as np + + +UNDISCOUNTED_POLICY_RETURNS = { + 'halfcheetah-medium' : 3985.8150261686337, + 'halfcheetah-random' : -199.26067391425954, + 'halfcheetah-expert' : 12330.945945279545, + 'hopper-medium' : 2260.1983114487352, + 'hopper-random' : 1257.9757846810203, + 'hopper-expert' : 3624.4696022560997, + 'walker2d-medium' : 2760.3310101980005, + 'walker2d-random' : 896.4751989935487, + 'walker2d-expert' : 4005.89370727539, +} + + +DISCOUNTED_POLICY_RETURNS = { + 'halfcheetah-medium' : 324.83583782709877, + 'halfcheetah-random' : -16.836944753939207, + 'halfcheetah-expert' : 827.7278887047698, + 'hopper-medium' : 235.7441494727478, + 'hopper-random' : 215.04955086664955, + 'hopper-expert' : 271.6925087260701, + 'walker2d-medium' : 202.23983424823822, + 'walker2d-random' : 78.46052021427765, + 'walker2d-expert' : 396.8752247768766 +} + + +def get_returns(policy_id, discounted=False): + if discounted: + return DISCOUNTED_POLICY_RETURNS[policy_id] + return UNDISCOUNTED_POLICY_RETURNS[policy_id] + + +def normalize(policy_id, score): + key = policy_id + '-v0' + min_score = infos.REF_MIN_SCORE[key] + max_score = infos.REF_MAX_SCORE[key] + return (score - min_score) / (max_score - min_score) + + +def ranking_correlation_metric(policies, discounted=False): + """ + Computes Spearman's rank correlation coefficient. + A score of 1.0 means the policies are ranked correctly according to their values. + A score of -1.0 means the policies are ranked inversely. + + Args: + policies: A list of policy string identifiers. + Valid identifiers must be contained in POLICY_RETURNS. + + Returns: + A correlation value between [-1, 1] + """ + return_values = np.array([get_returns(policy_key, discounted=discounted) for policy_key in policies]) + ranks = np.argsort(-return_values) + N = len(policies) + diff = ranks - np.arange(N) + return 1.0 - (6 * np.sum(diff ** 2)) / (N * (N**2 - 1)) + + +def precision_at_k_metric(policies, k=1, n_rel=None, discounted=False): + """ + Computes precision@k. + + Args: + policies: A list of policy string identifiers. + k (int): Number of top items. + n_rel (int): Number of relevant items. Default is k. + + Returns: + Fraction of top k policies in the top n_rel of the true rankings. + """ + assert len(policies) >= k + if n_rel is None: + n_rel = k + top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel] + policy_k = policies[:k] + score = sum([policy in top_k for policy in policy_k]) + return float(score) / k + + +def recall_at_k_metric(policies, k=1, n_rel=None, discounted=False): + """ + Computes recall@k. + + Args: + policies: A list of policy string identifiers. + k (int): Number of top items. + n_rel (int): Number of relevant items. Default is k. + + Returns: + Fraction of top n_rel true policy rankings in the top k of the given policies + """ + assert len(policies) >= k + if n_rel is None: + n_rel = k + top_k = sorted(policies, reverse=True, key=lambda x: get_returns(x, discounted=discounted))[:n_rel] + policy_k = policies[:k] + score = sum([policy in policy_k for policy in top_k]) + return float(score) / k + + +def value_error_metric(policy, value, discounted=False): + """ + Returns the absolute error in estimated value. + + Args: + policy (str): A policy string identifier. 
+ value (float): Estimated value + """ + return abs(normalize(policy, value) - normalize(policy, get_returns(policy, discounted))) + + +def policy_regret_metric(policy, expert_policies, discounted=False): + """ + Returns the regret of the given policy against a set of expert policies. + + Args: + policy (str): A policy string identifier. + expert_policies (list[str]): A list of expert policies + Returns: + The regret, which is value of the best expert minus the value of the policy. + """ + best_returns = max([get_returns(policy_key, discounted=discounted) for policy_key in expert_policies]) + return normalize(policy, best_returns) - normalize(policy, get_returns(policy, discounted=discounted)) + diff --git a/d4rl/d4rl/pointmaze/__init__.py b/d4rl/d4rl/pointmaze/__init__.py new file mode 100644 index 0000000..8892874 --- /dev/null +++ b/d4rl/d4rl/pointmaze/__init__.py @@ -0,0 +1,290 @@ +from .maze_model import MazeEnv, OPEN, U_MAZE, MEDIUM_MAZE, LARGE_MAZE, U_MAZE_EVAL, MEDIUM_MAZE_EVAL, LARGE_MAZE_EVAL +from gym.envs.registration import register + +register( + id='maze2d-open-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=150, + kwargs={ + 'maze_spec':OPEN, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 0.01, + 'ref_max_score': 20.66, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5' + } +) + +register( + id='maze2d-umaze-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=150, + kwargs={ + 'maze_spec':U_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 0.94, + 'ref_max_score': 62.6, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse.hdf5' + } +) + +register( + id='maze2d-medium-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=250, + kwargs={ + 'maze_spec':MEDIUM_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 5.77, + 'ref_max_score': 85.14, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse.hdf5' + } +) + + +register( + id='maze2d-large-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':LARGE_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 4.83, + 'ref_max_score': 191.99, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse.hdf5' + } +) + + +register( + id='maze2d-umaze-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=300, + kwargs={ + 'maze_spec':U_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 23.85, + 'ref_max_score': 161.86, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5' + } +) + +register( + id='maze2d-medium-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':MEDIUM_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 13.13, + 'ref_max_score': 277.39, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5' + } +) + + +register( + id='maze2d-large-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=800, + kwargs={ + 'maze_spec':LARGE_MAZE, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 6.7, + 'ref_max_score': 273.99, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5' + } +) + +register( + id='maze2d-eval-umaze-v1', + 
entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=300, + kwargs={ + 'maze_spec':U_MAZE_EVAL, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 36.63, + 'ref_max_score': 141.4, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5' + } +) + +register( + id='maze2d-eval-medium-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':MEDIUM_MAZE_EVAL, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 13.07, + 'ref_max_score': 204.93, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5' + } +) + + +register( + id='maze2d-eval-large-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=800, + kwargs={ + 'maze_spec':LARGE_MAZE_EVAL, + 'reward_type':'sparse', + 'reset_target': False, + 'ref_min_score': 16.4, + 'ref_max_score': 302.22, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5' + } +) + + +register( + id='maze2d-open-dense-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=150, + kwargs={ + 'maze_spec':OPEN, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 11.17817, + 'ref_max_score': 27.166538620695782, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5' + } +) + +register( + id='maze2d-umaze-dense-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=150, + kwargs={ + 'maze_spec':U_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 23.249793, + 'ref_max_score': 81.78995240126592, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense.hdf5' + } +) + +register( + id='maze2d-medium-dense-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=250, + kwargs={ + 'maze_spec':MEDIUM_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 19.477620, + 'ref_max_score': 96.03474232952358, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense.hdf5' + } +) + + +register( + id='maze2d-large-dense-v0', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':LARGE_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 27.388310, + 'ref_max_score': 215.09965671563742, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense.hdf5' + } +) + +register( + id='maze2d-umaze-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=300, + kwargs={ + 'maze_spec':U_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 68.537689, + 'ref_max_score': 193.66285642381482, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5' + } +) + +register( + id='maze2d-medium-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':MEDIUM_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 44.264742, + 'ref_max_score': 297.4552547777125, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5' + } +) + + +register( + id='maze2d-large-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=800, + kwargs={ + 'maze_spec':LARGE_MAZE, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 30.569041, + 'ref_max_score': 303.4857382709002, + 
'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5' + } +) + +register( + id='maze2d-eval-umaze-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=300, + kwargs={ + 'maze_spec':U_MAZE_EVAL, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 56.95455, + 'ref_max_score': 178.21373133248397, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5' + } +) + +register( + id='maze2d-eval-medium-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=600, + kwargs={ + 'maze_spec':MEDIUM_MAZE_EVAL, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 42.28578, + 'ref_max_score': 235.5658957482388, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5' + } +) + + +register( + id='maze2d-eval-large-dense-v1', + entry_point='d4rl.pointmaze:MazeEnv', + max_episode_steps=800, + kwargs={ + 'maze_spec':LARGE_MAZE_EVAL, + 'reward_type':'dense', + 'reset_target': False, + 'ref_min_score': 56.95455, + 'ref_max_score': 326.09647655082637, + 'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5' + } +) diff --git a/d4rl/d4rl/pointmaze/dynamic_mjc.py b/d4rl/d4rl/pointmaze/dynamic_mjc.py new file mode 100644 index 0000000..657554d --- /dev/null +++ b/d4rl/d4rl/pointmaze/dynamic_mjc.py @@ -0,0 +1,138 @@ +""" +dynamic_mjc.py +A small library for programatically building MuJoCo XML files +""" +from contextlib import contextmanager +import tempfile +import numpy as np + + +def default_model(name): + """ + Get a model with basic settings such as gravity and RK4 integration enabled + """ + model = MJCModel(name) + root = model.root + + # Setup + root.compiler(angle="radian", inertiafromgeom="true") + default = root.default() + default.joint(armature=1, damping=1, limited="true") + default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1') + root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01) + return model + +def pointmass_model(name): + """ + Get a model with basic settings such as gravity and Euler integration enabled + """ + model = MJCModel(name) + root = model.root + + # Setup + root.compiler(angle="radian", inertiafromgeom="true", coordinate="local") + default = root.default() + default.joint(limited="false", damping=1) + default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002") + root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler") + return model + + +class MJCModel(object): + def __init__(self, name): + self.name = name + self.root = MJCTreeNode("mujoco").add_attr('model', name) + + @contextmanager + def asfile(self): + """ + Usage: + model = MJCModel('reacher') + with model.asfile() as f: + print f.read() # prints a dump of the model + """ + with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f: + self.root.write(f) + f.seek(0) + yield f + + def open(self): + self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) + self.root.write(self.file) + self.file.seek(0) + return self.file + + def close(self): + self.file.close() + + def find_attr(self, attr, value): + return self.root.find_attr(attr, value) + + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + +class MJCTreeNode(object): + def __init__(self, name): + self.name = name + self.attrs = {} + self.children = [] + + def 
add_attr(self, key, value): + if isinstance(value, str): + pass + elif isinstance(value, list) or isinstance(value, np.ndarray): + value = ' '.join([str(val).lower() for val in value]) + else: + value = str(value).lower() + + self.attrs[key] = value + return self + + def __getattr__(self, name): + def wrapper(**kwargs): + newnode = MJCTreeNode(name) + for (k, v) in kwargs.items(): + newnode.add_attr(k, v) + self.children.append(newnode) + return newnode + return wrapper + + def dfs(self): + yield self + if self.children: + for child in self.children: + for node in child.dfs(): + yield node + + def find_attr(self, attr, value): + """ Run DFS to find a matching attr """ + if attr in self.attrs and self.attrs[attr] == value: + return self + for child in self.children: + res = child.find_attr(attr, value) + if res is not None: + return res + return None + + + def write(self, ostream, tabs=0): + contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) + if self.children: + ostream.write('\t'*tabs) + ostream.write('<%s %s>\n' % (self.name, contents)) + for child in self.children: + child.write(ostream, tabs=tabs+1) + ostream.write('\t'*tabs) + ostream.write('</%s>\n' % self.name) + else: + ostream.write('\t'*tabs) + ostream.write('<%s %s/>\n' % (self.name, contents)) + + def __str__(self): + s = "<"+self.name + s += ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) + return s+">" diff --git a/d4rl/d4rl/pointmaze/gridcraft/__init__.py b/d4rl/d4rl/pointmaze/gridcraft/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/pointmaze/gridcraft/grid_env.py b/d4rl/d4rl/pointmaze/gridcraft/grid_env.py new file mode 100644 index 0000000..372bd09 --- /dev/null +++ b/d4rl/d4rl/pointmaze/gridcraft/grid_env.py @@ -0,0 +1,210 @@ +import sys +import numpy as np +import gym +import gym.spaces + +from d4rl.pointmaze.gridcraft.grid_spec import REWARD, REWARD2, REWARD3, REWARD4, WALL, LAVA, TILES, START, RENDER_DICT +from d4rl.pointmaze.gridcraft.utils import one_hot_to_flat, flat_to_one_hot + +ACT_NOOP = 0 +ACT_UP = 1 +ACT_DOWN = 2 +ACT_LEFT = 3 +ACT_RIGHT = 4 +ACT_DICT = { + ACT_NOOP: [0,0], + ACT_UP: [0, -1], + ACT_LEFT: [-1, 0], + ACT_RIGHT: [+1, 0], + ACT_DOWN: [0, +1] +} +ACT_TO_STR = { + ACT_NOOP: 'NOOP', + ACT_UP: 'UP', + ACT_LEFT: 'LEFT', + ACT_RIGHT: 'RIGHT', + ACT_DOWN: 'DOWN' +} + +class TransitionModel(object): + def __init__(self, gridspec, eps=0.2): + self.gs = gridspec + self.eps = eps + + def get_aprobs(self, s, a): + # TODO: could probably output a matrix over all states...
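Stepping back to dynamic_mjc.py above: the MJCModel/MJCTreeNode helpers build MuJoCo XML by turning attribute access into child elements. A minimal usage sketch (element names and sizes here are illustrative, not taken from the original repo):

```python
# Hypothetical example of building and dumping a MuJoCo XML with MJCModel.
from d4rl.pointmaze.dynamic_mjc import MJCModel

model = MJCModel('toy')
model.root.compiler(angle="radian", inertiafromgeom="true")
world = model.root.worldbody()
# Keyword arguments become XML attributes; lists are joined into space-separated strings.
world.geom(name='floor', type='plane', size=[5, 5, 0.1], pos=[0, 0, 0])

with model.asfile() as f:
    print(f.read())   # dumps the generated <mujoco model="toy"> ... XML
```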
+ legal_moves = self.__get_legal_moves(s) + p = np.zeros(len(ACT_DICT)) + p[list(legal_moves)] = self.eps / (len(legal_moves)) + if a in legal_moves: + p[a] += 1.0-self.eps + else: + #p = np.array([1.0,0,0,0,0]) # NOOP + p[ACT_NOOP] += 1.0-self.eps + return p + + def __get_legal_moves(self, s): + xy = np.array(self.gs.idx_to_xy(s)) + moves = {move for move in ACT_DICT if not self.gs.out_of_bounds(xy+ACT_DICT[move]) + and self.gs[xy+ACT_DICT[move]] != WALL} + moves.add(ACT_NOOP) + return moves + + +class RewardFunction(object): + def __init__(self, rew_map=None, default=0): + if rew_map is None: + rew_map = { + REWARD: 1.0, + REWARD2: 2.0, + REWARD3: 4.0, + REWARD4: 8.0, + LAVA: -100.0, + } + self.default = default + self.rew_map = rew_map + + def __call__(self, gridspec, s, a, ns): + val = gridspec[gridspec.idx_to_xy(s)] + if val in self.rew_map: + return self.rew_map[val] + return self.default + + +class GridEnv(gym.Env): + def __init__(self, gridspec, + tiles=TILES, + rew_fn=None, + teps=0.0, + max_timesteps=None, + rew_map=None, + terminal_states=None, + default_rew=0): + self.num_states = len(gridspec) + self.num_actions = 5 + self._env_args = {'teps': teps, 'max_timesteps': max_timesteps} + self.gs = gridspec + self.model = TransitionModel(gridspec, eps=teps) + self.terminal_states = terminal_states + if rew_fn is None: + rew_fn = RewardFunction(rew_map=rew_map, default=default_rew) + self.rew_fn = rew_fn + self.possible_tiles = tiles + self.max_timesteps = max_timesteps + self._timestep = 0 + self._true_q = None # q_vals for debugging + super(GridEnv, self).__init__() + + def get_transitions(self, s, a): + tile_type = self.gs[self.gs.idx_to_xy(s)] + if tile_type == LAVA: # Lava gets you stuck + return {s: 1.0} + + aprobs = self.model.get_aprobs(s, a) + t_dict = {} + for sa in range(5): + if aprobs[sa] > 0: + next_s = self.gs.idx_to_xy(s) + ACT_DICT[sa] + next_s_idx = self.gs.xy_to_idx(next_s) + t_dict[next_s_idx] = t_dict.get(next_s_idx, 0.0) + aprobs[sa] + return t_dict + + + def step_stateless(self, s, a, verbose=False): + aprobs = self.model.get_aprobs(s, a) + samp_a = np.random.choice(range(5), p=aprobs) + + next_s = self.gs.idx_to_xy(s) + ACT_DICT[samp_a] + tile_type = self.gs[self.gs.idx_to_xy(s)] + if tile_type == LAVA: # Lava gets you stuck + next_s = self.gs.idx_to_xy(s) + + next_s_idx = self.gs.xy_to_idx(next_s) + rew = self.rew_fn(self.gs, s, samp_a, next_s_idx) + + if verbose: + print('Act: %s. 
Act Executed: %s' % (ACT_TO_STR[a], ACT_TO_STR[samp_a])) + return next_s_idx, rew + + def step(self, a, verbose=False): + ns, r = self.step_stateless(self.__state, a, verbose=verbose) + traj_infos = {} + self.__state = ns + obs = ns #flat_to_one_hot(ns, len(self.gs)) + + done = False + self._timestep += 1 + if self.max_timesteps is not None: + if self._timestep >= self.max_timesteps: + done = True + return obs, r, done, traj_infos + + def reset(self): + start_idxs = np.array(np.where(self.gs.spec == START)).T + start_idx = start_idxs[np.random.randint(0, start_idxs.shape[0])] + start_idx = self.gs.xy_to_idx(start_idx) + self.__state =start_idx + self._timestep = 0 + return start_idx #flat_to_one_hot(start_idx, len(self.gs)) + + def render(self, close=False, ostream=sys.stdout): + if close: + return + + state = self.__state + ostream.write('-'*(self.gs.width+2)+'\n') + for h in range(self.gs.height): + ostream.write('|') + for w in range(self.gs.width): + if self.gs.xy_to_idx((w,h)) == state: + ostream.write('*') + else: + val = self.gs[w, h] + ostream.write(RENDER_DICT[val]) + ostream.write('|\n') + ostream.write('-' * (self.gs.width + 2)+'\n') + + @property + def action_space(self): + return gym.spaces.Discrete(5) + + @property + def observation_space(self): + dO = len(self.gs) + #return gym.spaces.Box(0,1,shape=dO) + return gym.spaces.Discrete(dO) + + def transition_matrix(self): + """Constructs this environment's transition matrix. + + Returns: + A dS x dA x dS array where the entry transition_matrix[s, a, ns] + corrsponds to the probability of transitioning into state ns after taking + action a from state s. + """ + ds = self.num_states + da = self.num_actions + transition_matrix = np.zeros((ds, da, ds)) + for s in range(ds): + for a in range(da): + transitions = self.get_transitions(s,a) + for next_s in transitions: + transition_matrix[s, a, next_s] = transitions[next_s] + return transition_matrix + + def reward_matrix(self): + """Constructs this environment's reward matrix. + + Returns: + A dS x dA x dS numpy array where the entry reward_matrix[s, a, ns] + reward given to an agent when transitioning into state ns after taking + action s from state s. 
+ """ + ds = self.num_states + da = self.num_actions + rew_matrix = np.zeros((ds, da, ds)) + for s in range(ds): + for a in range(da): + for ns in range(ds): + rew_matrix[s, a, ns] = self.rew_fn(self.gs, s, a, ns) + return rew_matrix diff --git a/d4rl/d4rl/pointmaze/gridcraft/grid_spec.py b/d4rl/d4rl/pointmaze/gridcraft/grid_spec.py new file mode 100644 index 0000000..fbcbca6 --- /dev/null +++ b/d4rl/d4rl/pointmaze/gridcraft/grid_spec.py @@ -0,0 +1,163 @@ +import numpy as np + + +EMPTY = 110 +WALL = 111 +START = 112 +REWARD = 113 +OUT_OF_BOUNDS = 114 +REWARD2 = 115 +REWARD3 = 116 +REWARD4 = 117 +LAVA = 118 +GOAL = 119 + +TILES = {EMPTY, WALL, START, REWARD, REWARD2, REWARD3, REWARD4, LAVA, GOAL} + +STR_MAP = { + 'O': EMPTY, + '#': WALL, + 'S': START, + 'R': REWARD, + '2': REWARD2, + '3': REWARD3, + '4': REWARD4, + 'G': GOAL, + 'L': LAVA +} + +RENDER_DICT = {v:k for k, v in STR_MAP.items()} +RENDER_DICT[EMPTY] = ' ' +RENDER_DICT[START] = ' ' + + + +def spec_from_string(s, valmap=STR_MAP): + if s.endswith('\\'): + s = s[:-1] + rows = s.split('\\') + rowlens = np.array([len(row) for row in rows]) + assert np.all(rowlens == rowlens[0]) + w, h = len(rows), len(rows[0])#len(rows[0]), len(rows) + + gs = GridSpec(w, h) + for i in range(w): + for j in range(h): + gs[i,j] = valmap[rows[i][j]] + return gs + + +def spec_from_sparse_locations(w, h, tile_to_locs): + """ + + Example usage: + >> spec_from_sparse_locations(10, 10, {START: [(0,0)], REWARD: [(7,8), (8,8)]}) + + """ + gs = GridSpec(w, h) + for tile_type in tile_to_locs: + locs = np.array(tile_to_locs[tile_type]) + for i in range(locs.shape[0]): + gs[tuple(locs[i])] = tile_type + return gs + + +def local_spec(map, xpnt): + """ + >>> local_spec("yOy\\\\Oxy", xpnt=(5,5)) + array([[4, 4], + [6, 4], + [6, 5]]) + """ + Y = 0; X=1; O=2 + valmap={ + 'y': Y, + 'x': X, + 'O': O + } + gs = spec_from_string(map, valmap=valmap) + ys = gs.find(Y) + x = gs.find(X) + result = ys-x + np.array(xpnt) + return result + + + +class GridSpec(object): + def __init__(self, w, h): + self.__data = np.zeros((w, h), dtype=np.int32) + self.__w = w + self.__h = h + + def __setitem__(self, key, val): + self.__data[key] = val + + def __getitem__(self, key): + if self.out_of_bounds(key): + raise NotImplementedError("Out of bounds:"+str(key)) + return self.__data[tuple(key)] + + def out_of_bounds(self, wh): + """ Return true if x, y is out of bounds """ + w, h = wh + if w<0 or w>=self.__w: + return True + if h < 0 or h >= self.__h: + return True + return False + + def get_neighbors(self, k, xy=False): + """ Return values of up, down, left, and right tiles """ + if not xy: + k = self.idx_to_xy(k) + offsets = [np.array([0,-1]), np.array([0,1]), + np.array([-1,0]), np.array([1,0])] + neighbors = \ + [self[k+offset] if (not self.out_of_bounds(k+offset)) else OUT_OF_BOUNDS for offset in offsets ] + return neighbors + + def get_value(self, k, xy=False): + """ Return values of up, down, left, and right tiles """ + if not xy: + k = self.idx_to_xy(k) + return self[k] + + def find(self, value): + return np.array(np.where(self.spec == value)).T + + @property + def spec(self): + return self.__data + + @property + def width(self): + return self.__w + + def __len__(self): + return self.__w*self.__h + + @property + def height(self): + return self.__h + + def idx_to_xy(self, idx): + if hasattr(idx, '__len__'): # array + x = idx % self.__w + y = np.floor(idx/self.__w).astype(np.int32) + xy = np.c_[x,y] + return xy + else: + return np.array([ idx % self.__w, int(np.floor(idx/self.__w))]) + + 
def xy_to_idx(self, key): + shape = np.array(key).shape + if len(shape) == 1: + return key[0] + key[1]*self.__w + elif len(shape) == 2: + return key[:,0] + key[:,1]*self.__w + else: + raise NotImplementedError() + + def __hash__(self): + data = (self.__w, self.__h) + tuple(self.__data.reshape([-1]).tolist()) + return hash(data) diff --git a/d4rl/d4rl/pointmaze/gridcraft/utils.py b/d4rl/d4rl/pointmaze/gridcraft/utils.py new file mode 100644 index 0000000..ea12463 --- /dev/null +++ b/d4rl/d4rl/pointmaze/gridcraft/utils.py @@ -0,0 +1,35 @@ +import numpy as np + +def flat_to_one_hot(val, ndim): + """ + + >>> flat_to_one_hot(2, ndim=4) + array([ 0., 0., 1., 0.]) + >>> flat_to_one_hot(4, ndim=5) + array([ 0., 0., 0., 0., 1.]) + >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) + array([[ 0., 0., 1., 0., 0.], + [ 0., 0., 0., 0., 1.], + [ 0., 0., 0., 1., 0.]]) + """ + shape =np.array(val).shape + v = np.zeros(shape + (ndim,)) + if len(shape) == 1: + v[np.arange(shape[0]), val] = 1.0 + else: + v[val] = 1.0 + return v + +def one_hot_to_flat(val): + """ + >>> one_hot_to_flat(np.array([0,0,0,0,1])) + 4 + >>> one_hot_to_flat(np.array([0,0,1,0])) + 2 + >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) + array([2, 0, 1]) + """ + idxs = np.array(np.where(val == 1.0))[-1] + if len(val.shape) == 1: + return int(idxs) + return idxs \ No newline at end of file diff --git a/d4rl/d4rl/pointmaze/gridcraft/wrappers.py b/d4rl/d4rl/pointmaze/gridcraft/wrappers.py new file mode 100644 index 0000000..e134569 --- /dev/null +++ b/d4rl/d4rl/pointmaze/gridcraft/wrappers.py @@ -0,0 +1,120 @@ +import numpy as np +from d4rl.pointmaze.gridcraft.grid_env import REWARD, GridEnv +from d4rl.pointmaze.gridcraft.wrappers import ObsWrapper +from gym.spaces import Box + + +class GridObsWrapper(ObsWrapper): + def __init__(self, env): + super(GridObsWrapper, self).__init__(env) + + def render(self): + self.env.render() + + + +class EyesWrapper(ObsWrapper): + def __init__(self, env, range=4, types=(REWARD,), angle_thresh=0.8): + super(EyesWrapper, self).__init__(env) + self.types = types + self.range = range + self.angle_thresh = angle_thresh + + eyes_low = np.ones(5*len(types)) + eyes_high = np.ones(5*len(types)) + low = np.r_[env.observation_space.low, eyes_low] + high = np.r_[env.observation_space.high, eyes_high] + self.__observation_space = Box(low, high) + + def wrap_obs(self, obs, info=None): + gs = self.env.gs # grid spec + xy = gs.idx_to_xy(self.env.obs_to_state(obs)) + #xy = np.array([x, y]) + + extra_obs = [] + for tile_type in self.types: + idxs = gs.find(tile_type).astype(np.float32) # N x 2 + # gather all idxs that are close + diffs = idxs-np.expand_dims(xy, axis=0) + dists = np.linalg.norm(diffs, axis=1) + valid_idxs = np.where(dists <= self.range)[0] + if len(valid_idxs) == 0: + eye_data = np.array([0,0,0,0,0], dtype=np.float32) + else: + diffs = diffs[valid_idxs, :] + dists = dists[valid_idxs]+1e-6 + cosines = diffs[:,0]/dists + cosines = np.r_[cosines, 0] + sines = diffs[:,1]/dists + sines = np.r_[sines, 0] + on_target = 0.0 + if np.any(dists<=1.0): + on_target = 1.0 + eye_data = np.abs(np.array([on_target, np.max(cosines), np.min(cosines), np.max(sines), np.min(sines)])) + eye_data[np.where(eye_data<=self.angle_thresh)] = 0 + extra_obs.append(eye_data) + extra_obs = np.concatenate(extra_obs) + obs = np.r_[obs, extra_obs] + #if np.any(np.isnan(obs)): + # import pdb; pdb.set_trace() + return obs + + def unwrap_obs(self, obs, info=None): + if len(obs.shape) == 1: + return obs[:-5*len(self.types)] + else: + 
return obs[:,:-5*len(self.types)] + + @property + def observation_space(self): + return self.__observation_space + + +""" +class CoordinateWiseWrapper(GridObsWrapper): + def __init__(self, env): + assert isinstance(env, GridEnv) + super(CoordinateWiseWrapper, self).__init__(env) + self.gs = env.gs + self.dO = self.gs.width+self.gs.height + + self.__observation_space = Box(0, 1, self.dO) + + def wrap_obs(self, obs, info=None): + state = one_hot_to_flat(obs) + xy = self.gs.idx_to_xy(state) + x = flat_to_one_hot(xy[0], self.gs.width) + y = flat_to_one_hot(xy[1], self.gs.height) + obs = np.r_[x, y] + return obs + + def unwrap_obs(self, obs, info=None): + + if len(obs.shape) == 1: + x = obs[:self.gs.width] + y = obs[self.gs.width:] + x = one_hot_to_flat(x) + y = one_hot_to_flat(y) + state = self.gs.xy_to_idx(np.c_[x,y]) + return flat_to_one_hot(state, self.dO) + else: + raise NotImplementedError() +""" + + +class RandomObsWrapper(GridObsWrapper): + def __init__(self, env, dO): + assert isinstance(env, GridEnv) + super(RandomObsWrapper, self).__init__(env) + self.gs = env.gs + self.dO = dO + self.obs_matrix = np.random.randn(self.dO, len(self.gs)) + self.__observation_space = Box(np.min(self.obs_matrix), np.max(self.obs_matrix), + shape=(self.dO,), dtype=np.float32) + + def wrap_obs(self, obs, info=None): + return np.inner(self.obs_matrix, obs) + + def unwrap_obs(self, obs, info=None): + raise NotImplementedError() + diff --git a/d4rl/d4rl/pointmaze/maze_model.py b/d4rl/d4rl/pointmaze/maze_model.py new file mode 100644 index 0000000..cdfcf2f --- /dev/null +++ b/d4rl/d4rl/pointmaze/maze_model.py @@ -0,0 +1,245 @@ +""" A pointmass maze env.""" +from gym.envs.mujoco import mujoco_env +from gym import utils +from d4rl import offline_env +from d4rl.pointmaze.dynamic_mjc import MJCModel +import numpy as np +import random + + +WALL = 10 +EMPTY = 11 +GOAL = 12 + + +def parse_maze(maze_str): + lines = maze_str.strip().split('\\') + width, height = len(lines), len(lines[0]) + maze_arr = np.zeros((width, height), dtype=np.int32) + for w in range(width): + for h in range(height): + tile = lines[w][h] + if tile == '#': + maze_arr[w][h] = WALL + elif tile == 'G': + maze_arr[w][h] = GOAL + elif tile == ' ' or tile == 'O' or tile == '0': + maze_arr[w][h] = EMPTY + else: + raise ValueError('Unknown tile type: %s' % tile) + return maze_arr + + +def point_maze(maze_str): + maze_arr = parse_maze(maze_str) + + mjcmodel = MJCModel('point_maze') + mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local") + mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler") + default = mjcmodel.root.default() + default.joint(damping=1, limited='false') + default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1") + + asset = mjcmodel.root.asset() + asset.texture(type="2d",name="groundplane",builtin="checker",rgb1="0.2 0.3 0.4",rgb2="0.1 0.2 0.3",width=100,height=100) + asset.texture(name="skybox",type="skybox",builtin="gradient",rgb1=".4 .6 .8",rgb2="0 0 0", + width="800",height="800",mark="random",markrgb="1 1 1") + asset.material(name="groundplane",texture="groundplane",texrepeat="20 20") + asset.material(name="wall",rgba=".7 .5 .3 1") + asset.material(name="target",rgba=".6 .3 .3 1") + + visual = mjcmodel.root.visual() + visual.headlight(ambient=".4 .4 .4",diffuse=".8 .8 .8",specular="0.1 0.1 0.1") + visual.map(znear=.01) + visual.quality(shadowsize=2048) + + worldbody = mjcmodel.root.worldbody() + 
worldbody.geom(name='ground',size="40 40 0.25",pos="0 0 -0.1",type="plane",contype=1,conaffinity=0,material="groundplane") + + particle = worldbody.body(name='particle', pos=[1.2,1.2,0]) + particle.geom(name='particle_geom', type='sphere', size=0.1, rgba='0.0 0.0 1.0 0.0', contype=1) + particle.site(name='particle_site', pos=[0.0,0.0,0], size=0.2, rgba='0.3 0.6 0.3 1') + particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0]) + particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0]) + + worldbody.site(name='target_site', pos=[0.0,0.0,0], size=0.2, material='target') + + width, height = maze_arr.shape + for w in range(width): + for h in range(height): + if maze_arr[w,h] == WALL: + worldbody.geom(conaffinity=1, + type='box', + name='wall_%d_%d'%(w,h), + material='wall', + pos=[w+1.0,h+1.0,0], + size=[0.5,0.5,0.2]) + + actuator = mjcmodel.root.actuator() + actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100) + actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True, gear=100) + + return mjcmodel + + +LARGE_MAZE = \ + "############\\"+\ + "#OOOO#OOOOO#\\"+\ + "#O##O#O#O#O#\\"+\ + "#OOOOOO#OOO#\\"+\ + "#O####O###O#\\"+\ + "#OO#O#OOOOO#\\"+\ + "##O#O#O#O###\\"+\ + "#OO#OOO#OGO#\\"+\ + "############" + +LARGE_MAZE_EVAL = \ + "############\\"+\ + "#OO#OOO#OGO#\\"+\ + "##O###O#O#O#\\"+\ + "#OO#O#OOOOO#\\"+\ + "#O##O#OO##O#\\"+\ + "#OOOOOO#OOO#\\"+\ + "#O##O#O#O###\\"+\ + "#OOOO#OOOOO#\\"+\ + "############" + +MEDIUM_MAZE = \ + '########\\'+\ + '#OO##OO#\\'+\ + '#OO#OOO#\\'+\ + '##OOO###\\'+\ + '#OO#OOO#\\'+\ + '#O#OO#O#\\'+\ + '#OOO#OG#\\'+\ + "########" + +MEDIUM_MAZE_EVAL = \ + '########\\'+\ + '#OOOOOG#\\'+\ + '#O#O##O#\\'+\ + '#OOOO#O#\\'+\ + '###OO###\\'+\ + '#OOOOOO#\\'+\ + '#OO##OO#\\'+\ + "########" + +SMALL_MAZE = \ + "######\\"+\ + "#OOOO#\\"+\ + "#O##O#\\"+\ + "#OOOO#\\"+\ + "######" + +U_MAZE = \ + "#####\\"+\ + "#GOO#\\"+\ + "###O#\\"+\ + "#OOO#\\"+\ + "#####" + +U_MAZE_EVAL = \ + "#####\\"+\ + "#OOG#\\"+\ + "#O###\\"+\ + "#OOO#\\"+\ + "#####" + +OPEN = \ + "#######\\"+\ + "#OOOOO#\\"+\ + "#OOGOO#\\"+\ + "#OOOOO#\\"+\ + "#######" + + +class MazeEnv(mujoco_env.MujocoEnv, utils.EzPickle, offline_env.OfflineEnv): + def __init__(self, + maze_spec=U_MAZE, + reward_type='dense', + reset_target=False, + **kwargs): + offline_env.OfflineEnv.__init__(self, **kwargs) + + self.reset_target = reset_target + self.str_maze_spec = maze_spec + self.maze_arr = parse_maze(maze_spec) + self.reward_type = reward_type + self.reset_locations = list(zip(*np.where(self.maze_arr == EMPTY))) + self.reset_locations.sort() + + self._target = np.array([0.0,0.0]) + + model = point_maze(maze_spec) + with model.asfile() as f: + mujoco_env.MujocoEnv.__init__(self, model_path=f.name, frame_skip=1) + utils.EzPickle.__init__(self) + + # Set the default goal (overriden by a call to set_target) + # Try to find a goal if it exists + self.goal_locations = list(zip(*np.where(self.maze_arr == GOAL))) + if len(self.goal_locations) == 1: + self.set_target(self.goal_locations[0]) + elif len(self.goal_locations) > 1: + raise ValueError("More than 1 goal specified!") + else: + # If no goal, use the first empty tile + self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype)) + self.empty_and_goal_locations = self.reset_locations + self.goal_locations + + def step(self, action): + action = np.clip(action, -1.0, 1.0) + self.clip_velocity() + self.do_simulation(action, self.frame_skip) + self.set_marker() + ob = self._get_obs() + if 
self.reward_type == 'sparse': + reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0 + elif self.reward_type == 'dense': + reward = np.exp(-np.linalg.norm(ob[0:2] - self._target)) + else: + raise ValueError('Unknown reward type %s' % self.reward_type) + done = False + return ob, reward, done, {} + + def _get_obs(self): + return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel() + + def get_target(self): + return self._target + + def set_target(self, target_location=None): + if target_location is None: + idx = self.np_random.choice(len(self.empty_and_goal_locations)) + reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) + target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) + self._target = target_location + + def set_marker(self): + self.data.site_xpos[self.model.site_name2id('target_site')] = np.array([self._target[0]+1, self._target[1]+1, 0.0]) + + def clip_velocity(self): + qvel = np.clip(self.sim.data.qvel, -5.0, 5.0) + self.set_state(self.sim.data.qpos, qvel) + + def reset_model(self): + idx = self.np_random.choice(len(self.empty_and_goal_locations)) + reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) + qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 + self.set_state(qpos, qvel) + if self.reset_target: + self.set_target() + return self._get_obs() + + def reset_to_location(self, location): + self.sim.reset() + reset_location = np.array(location).astype(self.observation_space.dtype) + qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + pass + diff --git a/d4rl/d4rl/pointmaze/q_iteration.py b/d4rl/d4rl/pointmaze/q_iteration.py new file mode 100644 index 0000000..a849396 --- /dev/null +++ b/d4rl/d4rl/pointmaze/q_iteration.py @@ -0,0 +1,109 @@ +""" +Use q-iteration to solve for an optimal policy + +Usage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus) +""" +import numpy as np +from scipy.special import logsumexp as sp_lse + +def softmax(q, alpha=1.0): + q = (1.0/alpha)*q + q = q-np.max(q) + probs = np.exp(q) + probs = probs/np.sum(probs) + return probs + +def logsumexp(q, alpha=1.0, axis=1): + if alpha == 0: + return np.max(q, axis=axis) + return alpha*sp_lse((1.0/alpha)*q, axis=axis) + + +def get_policy(q_fn, ent_wt=1.0): + v_rew = logsumexp(q_fn, alpha=ent_wt) + adv_rew = q_fn - np.expand_dims(v_rew, axis=1) + if ent_wt == 0: + pol_probs = adv_rew + pol_probs[pol_probs >= 0 ] = 1.0 + pol_probs[pol_probs < 0 ] = 0.0 + else: + pol_probs = np.exp((1.0/ent_wt)*adv_rew) + pol_probs /= np.sum(pol_probs, axis=1, keepdims=True) + assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs) + return pol_probs + + +def softq_iteration(env, transition_matrix=None, reward_matrix=None, num_itrs=50, discount=0.99, ent_wt=0.1, warmstart_q=None, policy=None): + """ + Perform tabular soft Q-iteration + """ + dim_obs = env.num_states + dim_act = env.num_actions + if reward_matrix is None: + reward_matrix = env.reward_matrix() + reward_matrix = reward_matrix[:,:,0] + + if warmstart_q is None: + q_fn = np.zeros((dim_obs, dim_act)) + else: + q_fn = warmstart_q + + if transition_matrix is None: + t_matrix = env.transition_matrix() 
+ else: + t_matrix = transition_matrix + + for k in range(num_itrs): + if policy is None: + v_fn = logsumexp(q_fn, alpha=ent_wt) + else: + v_fn = np.sum((q_fn - ent_wt*np.log(policy))*policy, axis=1) + new_q = reward_matrix + discount*t_matrix.dot(v_fn) + q_fn = new_q + return q_fn + + +def q_iteration(env, **kwargs): + return softq_iteration(env, ent_wt=0.0, **kwargs) + + +def compute_visitation(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): + pol_probs = get_policy(q_fn, ent_wt=ent_wt) + + dim_obs = env.num_states + dim_act = env.num_actions + state_visitation = np.zeros((dim_obs, 1)) + for (state, prob) in env.initial_state_distribution.items(): + state_visitation[state] = prob + t_matrix = env.transition_matrix() # S x A x S + sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) + + for i in range(env_time_limit): + sa_visit = state_visitation * pol_probs + # sa_visit_t[:, :, i] = (discount ** i) * sa_visit + sa_visit_t[:, :, i] = sa_visit + # sum-out (SA)S + new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) + state_visitation = np.expand_dims(new_state_visitation, axis=1) + return np.sum(sa_visit_t, axis=2) / float(env_time_limit) + + +def compute_occupancy(env, q_fn, ent_wt=1.0, env_time_limit=50, discount=1.0): + pol_probs = get_policy(q_fn, ent_wt=ent_wt) + + dim_obs = env.num_states + dim_act = env.num_actions + state_visitation = np.zeros((dim_obs, 1)) + for (state, prob) in env.initial_state_distribution.items(): + state_visitation[state] = prob + t_matrix = env.transition_matrix() # S x A x S + sa_visit_t = np.zeros((dim_obs, dim_act, env_time_limit)) + + for i in range(env_time_limit): + sa_visit = state_visitation * pol_probs + sa_visit_t[:, :, i] = (discount ** i) * sa_visit + # sa_visit_t[:, :, i] = sa_visit + # sum-out (SA)S + new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) + state_visitation = np.expand_dims(new_state_visitation, axis=1) + return np.sum(sa_visit_t, axis=2) #/ float(env_time_limit) diff --git a/d4rl/d4rl/pointmaze/waypoint_controller.py b/d4rl/d4rl/pointmaze/waypoint_controller.py new file mode 100644 index 0000000..d7601a4 --- /dev/null +++ b/d4rl/d4rl/pointmaze/waypoint_controller.py @@ -0,0 +1,109 @@ +import numpy as np +from d4rl.pointmaze import q_iteration +from d4rl.pointmaze.gridcraft import grid_env +from d4rl.pointmaze.gridcraft import grid_spec + + +ZEROS = np.zeros((2,), dtype=np.float32) +ONES = np.zeros((2,), dtype=np.float32) + + +class WaypointController(object): + def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0): + self.maze_str = maze_str + self._target = -1000 * ONES + + self.p_gain = p_gain + self.d_gain = d_gain + self.solve_thresh = solve_thresh + self.vel_thresh = 0.1 + + self._waypoint_idx = 0 + self._waypoints = [] + self._waypoint_prev_loc = ZEROS + + self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str)) + + def current_waypoint(self): + return self._waypoints[self._waypoint_idx] + + def get_action(self, location, velocity, target): + if np.linalg.norm(self._target - np.array(self.gridify_state(target))) > 1e-3: + #print('New target!', target, 'old:', self._target) + self._new_target(location, target) + + dist = np.linalg.norm(location - self._target) + vel = self._waypoint_prev_loc - location + vel_norm = np.linalg.norm(vel) + task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh) + + if task_not_solved: + next_wpnt = self._waypoints[self._waypoint_idx] + else: + next_wpnt = self._target + + # Compute control + prop = 
next_wpnt - location + action = self.p_gain * prop + self.d_gain * velocity + + dist_next_wpnt = np.linalg.norm(location - next_wpnt) + if task_not_solved and (dist_next_wpnt < self.solve_thresh) and (vel_norm 1: + raise ValueError("More than 1 goal specified!") + else: + # If no goal, use the first empty tile + self.set_target(np.array(self.reset_locations[0]).astype(self.observation_space.dtype)) + self.empty_and_goal_locations = self.reset_locations + self.goal_locations + + def create_single_player_scene(self, bullet_client): + return scene_abstract.SingleRobotEmptyScene(bullet_client, gravity=9.8, timestep=0.0165, frame_skip=1) + + def reset(self): + if (self.stateId >= 0): + self._p.restoreState(self.stateId) + r = env_bases.MJCFBaseBulletEnv.reset(self) + if (self.stateId < 0): + self.stateId = self._p.saveState() + + self.reset_model() + ob = self.robot.calc_state() + return ob + + def step(self, action): + action = np.clip(action, -1.0, 1.0) + #self.clip_velocity() + self.robot.apply_action(action) + self.scene.global_step() + ob = self.robot.calc_state() + if self.reward_type == 'sparse': + reward = 1.0 if np.linalg.norm(ob[0:2] - self._target) <= 0.5 else 0.0 + elif self.reward_type == 'dense': + reward = np.exp(-np.linalg.norm(ob[0:2] - self._target)) + else: + raise ValueError('Unknown reward type %s' % self.reward_type) + done = False + self.HUD(ob, action, done) + return ob, reward, done, {} + + def camera_adjust(self): + qpos = self.robot.qpos + x = qpos[0] + y = qpos[1] + self.camera.move_and_look_at(x, y, 1.4, x, y, 1.0) + + def get_target(self): + return self._target + + def set_target(self, target_location=None): + if target_location is None: + idx = self.np_random.choice(len(self.empty_and_goal_locations)) + reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) + target_location = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) + self._target = target_location + + def clip_velocity(self): + qvel = np.clip(self.robot.qvel, -5.0, 5.0) + self.robot.set_state(self.robot.qpos, qvel) + + def reset_model(self): + idx = self.np_random.choice(len(self.empty_and_goal_locations)) + reset_location = np.array(self.empty_and_goal_locations[idx]).astype(self.observation_space.dtype) + qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) + qvel = self.np_random.randn(2) * .1 + self.robot.set_state(qpos, qvel) + if self.reset_target: + self.set_target() + return self.robot.get_obs() + + def reset_to_location(self, location): + self.sim.reset() + reset_location = np.array(location).astype(self.observation_space.dtype) + qpos = reset_location + self.np_random.uniform(low=-.1, high=.1, size=2) + qvel = self.np_random.randn(2) * .1 + self.robot.set_state(qpos, qvel) + return self.robot.get_obs() + diff --git a/d4rl/d4rl/pointmaze_bullet/bullet_robot.py b/d4rl/d4rl/pointmaze_bullet/bullet_robot.py new file mode 100644 index 0000000..a14d9dd --- /dev/null +++ b/d4rl/d4rl/pointmaze_bullet/bullet_robot.py @@ -0,0 +1,126 @@ +import os +import pybullet +from pybullet_envs import robot_bases + +class MJCFBasedRobot(robot_bases.XmlBasedRobot): + """ + Base class for mujoco .xml based agents. 
+ """ + + def __init__(self, model_xml, robot_name, action_dim, obs_dim, self_collision=True): + robot_bases.XmlBasedRobot.__init__(self, robot_name, action_dim, obs_dim, self_collision) + self.model_xml = model_xml + self.doneLoading = 0 + + def reset(self, bullet_client): + + self._p = bullet_client + #print("Created bullet_client with id=", self._p._client) + if (self.doneLoading == 0): + self.ordered_joints = [] + self.doneLoading = 1 + if self.self_collision: + self.objects = self._p.loadMJCF(self.model_xml, + flags=pybullet.URDF_USE_SELF_COLLISION | + pybullet.URDF_USE_SELF_COLLISION_EXCLUDE_ALL_PARENTS | + pybullet.URDF_GOOGLEY_UNDEFINED_COLORS ) + self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene( + self._p, self.objects) + else: + self.objects = self._p.loadMJCF(self.model_xml, flags = pybullet.URDF_GOOGLEY_UNDEFINED_COLORS) + self.parts, self.jdict, self.ordered_joints, self.robot_body = self.addToScene( + self._p, self.objects) + self.robot_specific_reset(self._p) + + s = self.calc_state( + ) # optimization: calc_state() can calculate something in self.* for calc_potential() to use + + return s + + def calc_potential(self): + return 0 + + +class WalkerBase(MJCFBasedRobot): + + def __init__(self, fn, robot_name, action_dim, obs_dim, power): + MJCFBasedRobot.__init__(self, fn, robot_name, action_dim, obs_dim) + self.power = power + self.camera_x = 0 + self.start_pos_x, self.start_pos_y, self.start_pos_z = 0, 0, 0 + self.walk_target_x = 1e3 # kilometer away + self.walk_target_y = 0 + self.body_xyz = [0, 0, 0] + + def robot_specific_reset(self, bullet_client): + self._p = bullet_client + for j in self.ordered_joints: + j.reset_current_position(self.np_random.uniform(low=-0.1, high=0.1), 0) + + self.feet = [self.parts[f] for f in self.foot_list] + self.feet_contact = np.array([0.0 for f in self.foot_list], dtype=np.float32) + self.scene.actor_introduce(self) + self.initial_z = None + + def apply_action(self, a): + assert (np.isfinite(a).all()) + for n, j in enumerate(self.ordered_joints): + j.set_motor_torque(self.power * j.power_coef * float(np.clip(a[n], -1, +1))) + + def calc_state(self): + j = np.array([j.current_relative_position() for j in self.ordered_joints], + dtype=np.float32).flatten() + # even elements [0::2] position, scaled to -1..+1 between limits + # odd elements [1::2] angular speed, scaled to show -1..+1 + self.joint_speeds = j[1::2] + self.joints_at_limit = np.count_nonzero(np.abs(j[0::2]) > 0.99) + + body_pose = self.robot_body.pose() + parts_xyz = np.array([p.pose().xyz() for p in self.parts.values()]).flatten() + self.body_xyz = (parts_xyz[0::3].mean(), parts_xyz[1::3].mean(), body_pose.xyz()[2] + ) # torso z is more informative than mean z + self.body_real_xyz = body_pose.xyz() + self.body_rpy = body_pose.rpy() + z = self.body_xyz[2] + if self.initial_z == None: + self.initial_z = z + r, p, yaw = self.body_rpy + self.walk_target_theta = np.arctan2(self.walk_target_y - self.body_xyz[1], + self.walk_target_x - self.body_xyz[0]) + self.walk_target_dist = np.linalg.norm( + [self.walk_target_y - self.body_xyz[1], self.walk_target_x - self.body_xyz[0]]) + angle_to_target = self.walk_target_theta - yaw + + rot_speed = np.array([[np.cos(-yaw), -np.sin(-yaw), 0], [np.sin(-yaw), + np.cos(-yaw), 0], [0, 0, 1]]) + vx, vy, vz = np.dot(rot_speed, + self.robot_body.speed()) # rotate speed back to body point of view + + more = np.array( + [ + z - self.initial_z, + np.sin(angle_to_target), + np.cos(angle_to_target), + 0.3 * vx, + 0.3 * vy, + 0.3 * 
vz, # 0.3 is just scaling typical speed into -1..+1, no physical sense here + r, + p + ], + dtype=np.float32) + return np.clip(np.concatenate([more] + [j] + [self.feet_contact]), -5, +5) + + def calc_potential(self): + # progress in potential field is speed*dt, typical speed is about 2-3 meter per second, this potential will change 2-3 per frame (not per second), + # all rewards have rew/frame units and close to 1.0 + debugmode = 0 + if (debugmode): + print("calc_potential: self.walk_target_dist") + print(self.walk_target_dist) + print("self.scene.dt") + print(self.scene.dt) + print("self.scene.frame_skip") + print(self.scene.frame_skip) + print("self.scene.timestep") + print(self.scene.timestep) + return -self.walk_target_dist / self.scene.dt diff --git a/d4rl/d4rl/utils/__init__.py b/d4rl/d4rl/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/d4rl/d4rl/utils/dataset_utils.py b/d4rl/d4rl/utils/dataset_utils.py new file mode 100644 index 0000000..0d26b49 --- /dev/null +++ b/d4rl/d4rl/utils/dataset_utils.py @@ -0,0 +1,55 @@ +import h5py +import numpy as np + +class DatasetWriter(object): + def __init__(self, mujoco=False, goal=False): + self.mujoco = mujoco + self.goal = goal + self.data = self._reset_data() + self._num_samples = 0 + + def _reset_data(self): + data = {'observations': [], + 'actions': [], + 'terminals': [], + 'rewards': [], + } + if self.mujoco: + data['infos/qpos'] = [] + data['infos/qvel'] = [] + if self.goal: + data['infos/goal'] = [] + return data + + def __len__(self): + return self._num_samples + + def append_data(self, s, a, r, done, goal=None, mujoco_env_data=None): + self._num_samples += 1 + self.data['observations'].append(s) + self.data['actions'].append(a) + self.data['rewards'].append(r) + self.data['terminals'].append(done) + if self.goal: + self.data['infos/goal'].append(goal) + if self.mujoco: + self.data['infos/qpos'].append(mujoco_env_data.qpos.ravel().copy()) + self.data['infos/qvel'].append(mujoco_env_data.qvel.ravel().copy()) + + def write_dataset(self, fname, max_size=None, compression='gzip'): + np_data = {} + for k in self.data: + if k == 'terminals': + dtype = np.bool_ + else: + dtype = np.float32 + data = np.array(self.data[k], dtype=dtype) + if max_size is not None: + data = data[:max_size] + np_data[k] = data + + dataset = h5py.File(fname, 'w') + for k in np_data: + dataset.create_dataset(k, data=np_data[k], compression=compression) + dataset.close() + diff --git a/d4rl/d4rl/utils/quatmath.py b/d4rl/d4rl/utils/quatmath.py new file mode 100644 index 0000000..7fef129 --- /dev/null +++ b/d4rl/d4rl/utils/quatmath.py @@ -0,0 +1,164 @@ +import numpy as np +# For testing whether a number is close to zero +_FLOAT_EPS = np.finfo(np.float64).eps +_EPS4 = _FLOAT_EPS * 4.0 + + +def mulQuat(qa, qb): + res = np.zeros(4) + res[0] = qa[0]*qb[0] - qa[1]*qb[1] - qa[2]*qb[2] - qa[3]*qb[3] + res[1] = qa[0]*qb[1] + qa[1]*qb[0] + qa[2]*qb[3] - qa[3]*qb[2] + res[2] = qa[0]*qb[2] - qa[1]*qb[3] + qa[2]*qb[0] + qa[3]*qb[1] + res[3] = qa[0]*qb[3] + qa[1]*qb[2] - qa[2]*qb[1] + qa[3]*qb[0] + return res + +def negQuat(quat): + return np.array([quat[0], -quat[1], -quat[2], -quat[3]]) + +def quat2Vel(quat, dt=1): + axis = quat[1:].copy() + sin_a_2 = np.sqrt(np.sum(axis**2)) + axis = axis/(sin_a_2+1e-8) + speed = 2*np.arctan2(sin_a_2, quat[0])/dt + return speed, axis + +def quatDiff2Vel(quat1, quat2, dt): + neg = negQuat(quat1) + diff = mulQuat(quat2, neg) + return quat2Vel(diff, dt) + + +def axis_angle2quat(axis, angle): + c = np.cos(angle/2) + s = 
np.sin(angle/2) + return np.array([c, s*axis[0], s*axis[1], s*axis[2]]) + +def euler2mat(euler): + """ Convert Euler Angles to Rotation Matrix. See rotation.py for notes """ + euler = np.asarray(euler, dtype=np.float64) + assert euler.shape[-1] == 3, "Invalid shaped euler {}".format(euler) + + ai, aj, ak = -euler[..., 2], -euler[..., 1], -euler[..., 0] + si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) + ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) + cc, cs = ci * ck, ci * sk + sc, ss = si * ck, si * sk + + mat = np.empty(euler.shape[:-1] + (3, 3), dtype=np.float64) + mat[..., 2, 2] = cj * ck + mat[..., 2, 1] = sj * sc - cs + mat[..., 2, 0] = sj * cc + ss + mat[..., 1, 2] = cj * sk + mat[..., 1, 1] = sj * ss + cc + mat[..., 1, 0] = sj * cs - sc + mat[..., 0, 2] = -sj + mat[..., 0, 1] = cj * si + mat[..., 0, 0] = cj * ci + return mat + + +def euler2quat(euler): + """ Convert Euler Angles to Quaternions. See rotation.py for notes """ + euler = np.asarray(euler, dtype=np.float64) + assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler) + + ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2 + si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) + ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) + cc, cs = ci * ck, ci * sk + sc, ss = si * ck, si * sk + + quat = np.empty(euler.shape[:-1] + (4,), dtype=np.float64) + quat[..., 0] = cj * cc + sj * ss + quat[..., 3] = cj * sc - sj * cs + quat[..., 2] = -(cj * ss + sj * cc) + quat[..., 1] = cj * cs - sj * sc + return quat + + +def mat2euler(mat): + """ Convert Rotation Matrix to Euler Angles. See rotation.py for notes """ + mat = np.asarray(mat, dtype=np.float64) + assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) + + cy = np.sqrt(mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2]) + condition = cy > _EPS4 + euler = np.empty(mat.shape[:-1], dtype=np.float64) + euler[..., 2] = np.where(condition, + -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), + -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1])) + euler[..., 1] = np.where(condition, + -np.arctan2(-mat[..., 0, 2], cy), + -np.arctan2(-mat[..., 0, 2], cy)) + euler[..., 0] = np.where(condition, + -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), + 0.0) + return euler + + +def mat2quat(mat): + """ Convert Rotation Matrix to Quaternion. 
See rotation.py for notes """ + mat = np.asarray(mat, dtype=np.float64) + assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) + + Qxx, Qyx, Qzx = mat[..., 0, 0], mat[..., 0, 1], mat[..., 0, 2] + Qxy, Qyy, Qzy = mat[..., 1, 0], mat[..., 1, 1], mat[..., 1, 2] + Qxz, Qyz, Qzz = mat[..., 2, 0], mat[..., 2, 1], mat[..., 2, 2] + # Fill only lower half of symmetric matrix + K = np.zeros(mat.shape[:-2] + (4, 4), dtype=np.float64) + K[..., 0, 0] = Qxx - Qyy - Qzz + K[..., 1, 0] = Qyx + Qxy + K[..., 1, 1] = Qyy - Qxx - Qzz + K[..., 2, 0] = Qzx + Qxz + K[..., 2, 1] = Qzy + Qyz + K[..., 2, 2] = Qzz - Qxx - Qyy + K[..., 3, 0] = Qyz - Qzy + K[..., 3, 1] = Qzx - Qxz + K[..., 3, 2] = Qxy - Qyx + K[..., 3, 3] = Qxx + Qyy + Qzz + K /= 3.0 + # TODO: vectorize this -- probably could be made faster + q = np.empty(K.shape[:-2] + (4,)) + it = np.nditer(q[..., 0], flags=['multi_index']) + while not it.finished: + # Use Hermitian eigenvectors, values for speed + vals, vecs = np.linalg.eigh(K[it.multi_index]) + # Select largest eigenvector, reorder to w,x,y,z quaternion + q[it.multi_index] = vecs[[3, 0, 1, 2], np.argmax(vals)] + # Prefer quaternion with positive w + # (q * -1 corresponds to same rotation as q) + if q[it.multi_index][0] < 0: + q[it.multi_index] *= -1 + it.iternext() + return q + + +def quat2euler(quat): + """ Convert Quaternion to Euler Angles. See rotation.py for notes """ + return mat2euler(quat2mat(quat)) + + +def quat2mat(quat): + """ Convert Quaternion to Euler Angles. See rotation.py for notes """ + quat = np.asarray(quat, dtype=np.float64) + assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat) + + w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + Nq = np.sum(quat * quat, axis=-1) + s = 2.0 / Nq + X, Y, Z = x * s, y * s, z * s + wX, wY, wZ = w * X, w * Y, w * Z + xX, xY, xZ = x * X, x * Y, x * Z + yY, yZ, zZ = y * Y, y * Z, z * Z + + mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64) + mat[..., 0, 0] = 1.0 - (yY + zZ) + mat[..., 0, 1] = xY - wZ + mat[..., 0, 2] = xZ + wY + mat[..., 1, 0] = xY + wZ + mat[..., 1, 1] = 1.0 - (xX + zZ) + mat[..., 1, 2] = yZ - wX + mat[..., 2, 0] = xZ - wY + mat[..., 2, 1] = yZ + wX + mat[..., 2, 2] = 1.0 - (xX + yY) + return np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, np.eye(3)) \ No newline at end of file diff --git a/d4rl/d4rl/utils/visualize_env.py b/d4rl/d4rl/utils/visualize_env.py new file mode 100644 index 0000000..52265db --- /dev/null +++ b/d4rl/d4rl/utils/visualize_env.py @@ -0,0 +1,50 @@ +import gym +import d4rl +import click +import os +import gym +import numpy as np +import pickle +from mjrl.utils.gym_env import GymEnv +#from mjrl.policies.gaussian_mlp import MLP + +DESC = ''' +Helper script to visualize policy (in mjrl format).\n +USAGE:\n + Visualizes policy on the env\n + $ python visualize_policy.py --env_name door-v0 \n + $ python visualize_policy.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n +''' + +class RandomPolicy(object): + def __init__(self, env): + self.env = env + + def get_action(self, obs): + return [self.env.action_space.sample(), + {'evaluation': self.env.action_space.sample()}] + + +# MAIN ========================================================= +@click.command(help=DESC) +@click.option('--env_name', type=str, help='environment to load', required= True) +@click.option('--policy', type=str, help='absolute path of the policy file', default=None) +@click.option('--mode', type=str, help='exploration or evaluation mode for 
policy', default='evaluation') +@click.option('--seed', type=int, help='seed for generating environment instances', default=123) +@click.option('--episodes', type=int, help='number of episodes to visualize', default=10) + +def main(env_name, policy, mode, seed, episodes): + e = GymEnv(env_name) + e.set_seed(seed) + """ + if policy is not None: + pi = pickle.load(open(policy, 'rb')) + else: + pi = MLP(e.spec, hidden_sizes=(32,32), seed=seed, init_log_std=-1.0) + """ + pi = RandomPolicy(e) + # render policy + e.visualize_policy(pi, num_episodes=episodes, horizon=e.horizon, mode=mode) + +if __name__ == '__main__': + main() diff --git a/d4rl/d4rl/utils/wrappers.py b/d4rl/d4rl/utils/wrappers.py new file mode 100644 index 0000000..f01b2c2 --- /dev/null +++ b/d4rl/d4rl/utils/wrappers.py @@ -0,0 +1,171 @@ +import numpy as np +import itertools +from gym import Env +from gym.spaces import Box +from gym.spaces import Discrete + +from collections import deque + + +class ProxyEnv(Env): + def __init__(self, wrapped_env): + self._wrapped_env = wrapped_env + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + return self._wrapped_env.reset(**kwargs) + + def step(self, action): + return self._wrapped_env.step(action) + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + def seed(self, seed=0): + return self._wrapped_env.seed(seed=seed) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self.wrapped_env.terminate() + + def __getattr__(self, attr): + if attr == '_wrapped_env': + raise AttributeError() + return getattr(self._wrapped_env, attr) + + def __getstate__(self): + """ + This is useful to override in case the wrapped env has some funky + __getstate__ that doesn't play well with overriding __getattr__. + + The main problematic case is/was gym's EzPickle serialization scheme. 
+ :return: + """ + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return '{}({})'.format(type(self).__name__, self.wrapped_env) + + +class HistoryEnv(ProxyEnv, Env): + def __init__(self, wrapped_env, history_len): + super().__init__(wrapped_env) + self.history_len = history_len + + high = np.inf * np.ones( + self.history_len * self.observation_space.low.size) + low = -high + self.observation_space = Box(low=low, + high=high, + ) + self.history = deque(maxlen=self.history_len) + + def step(self, action): + state, reward, done, info = super().step(action) + self.history.append(state) + flattened_history = self._get_history().flatten() + return flattened_history, reward, done, info + + def reset(self, **kwargs): + state = super().reset() + self.history = deque(maxlen=self.history_len) + self.history.append(state) + flattened_history = self._get_history().flatten() + return flattened_history + + def _get_history(self): + observations = list(self.history) + + obs_count = len(observations) + for _ in range(self.history_len - obs_count): + dummy = np.zeros(self._wrapped_env.observation_space.low.size) + observations.append(dummy) + return np.c_[observations] + + +class DiscretizeEnv(ProxyEnv, Env): + def __init__(self, wrapped_env, num_bins): + super().__init__(wrapped_env) + low = self.wrapped_env.action_space.low + high = self.wrapped_env.action_space.high + action_ranges = [ + np.linspace(low[i], high[i], num_bins) + for i in range(len(low)) + ] + self.idx_to_continuous_action = [ + np.array(x) for x in itertools.product(*action_ranges) + ] + self.action_space = Discrete(len(self.idx_to_continuous_action)) + + def step(self, action): + continuous_action = self.idx_to_continuous_action[action] + return super().step(continuous_action) + + +class NormalizedBoxEnv(ProxyEnv): + """ + Normalize action to in [-1, 1]. + + Optionally normalize observations and scale reward. + """ + + def __init__( + self, + env, + reward_scale=1., + obs_mean=None, + obs_std=None, + ): + ProxyEnv.__init__(self, env) + self._should_normalize = not (obs_mean is None and obs_std is None) + if self._should_normalize: + if obs_mean is None: + obs_mean = np.zeros_like(env.observation_space.low) + else: + obs_mean = np.array(obs_mean) + if obs_std is None: + obs_std = np.ones_like(env.observation_space.low) + else: + obs_std = np.array(obs_std) + self._reward_scale = reward_scale + self._obs_mean = obs_mean + self._obs_std = obs_std + ub = np.ones(self._wrapped_env.action_space.shape) + self.action_space = Box(-1 * ub, ub) + + def estimate_obs_stats(self, obs_batch, override_values=False): + if self._obs_mean is not None and not override_values: + raise Exception("Observation mean and std already set. To " + "override, set override_values to True.") + self._obs_mean = np.mean(obs_batch, axis=0) + self._obs_std = np.std(obs_batch, axis=0) + + def _apply_normalize_obs(self, obs): + return (obs - self._obs_mean) / (self._obs_std + 1e-8) + + def step(self, action): + lb = self._wrapped_env.action_space.low + ub = self._wrapped_env.action_space.high + scaled_action = lb + (action + 1.) 
* 0.5 * (ub - lb) + scaled_action = np.clip(scaled_action, lb, ub) + + wrapped_step = self._wrapped_env.step(scaled_action) + next_obs, reward, done, info = wrapped_step + if self._should_normalize: + next_obs = self._apply_normalize_obs(next_obs) + return next_obs, reward * self._reward_scale, done, info + + def __str__(self): + return "Normalized: %s" % self._wrapped_env diff --git a/d4rl/scripts/check_antmaze_datasets.py b/d4rl/scripts/check_antmaze_datasets.py new file mode 100644 index 0000000..ea4070b --- /dev/null +++ b/d4rl/scripts/check_antmaze_datasets.py @@ -0,0 +1,97 @@ +""" +This script runs sanity checks on all datasets in a directory. + +Usage: + +python check_antmaze_datasets.py <dirname> +""" +import numpy as np +import scipy as sp +import scipy.spatial +import h5py +import os +import argparse + + +def check_identical_values(dset): + """ Check that values are not identical """ + check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel'] + + for k in check_keys: + values = dset[k][:] + + values_0 = values[0] + values_mid = values[values.shape[0]//2] + values_last = values[-1] + values = np.c_[values_0, values_mid, values_last].T + dists = sp.spatial.distance.pdist(values) + not_same = dists > 0 + assert np.all(not_same) + + +def check_num_samples(dset): + """ Check that all keys have the same # samples """ + check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel'] + + N = None + for k in check_keys: + values = dset[k] + if N is None: + N = values.shape[0] + else: + assert values.shape[0] == N + + +def check_reset_nonterminal(dataset): + """ Check if a reset occurred on a non-terminal state.""" + positions = dataset['observations'][:-1,0:2] + next_positions = dataset['observations'][1:,0:2] + diffs = np.linalg.norm(positions-next_positions, axis=1) + terminal = ((dataset['terminals'][:] + dataset['timeouts'][:]) > 0)[:-1] + + num_resets = np.sum(diffs > 5.0) + num_nonterminal_reset = np.sum( (diffs > 5.0) * (1-terminal)) + + print('num resets:', num_resets) + print('num nonterminal resets:', num_nonterminal_reset) + + assert num_nonterminal_reset == 0 + +def print_avg_returns(dset): + """ Print returns for manual sanity checking. """ + rew = dset['rewards'][:] + terminals = dset['terminals'][:] + timeouts = dset['timeouts'][:] + end_episode = (timeouts + terminals) > 0 + + all_returns = [] + returns = 0 + for i in range(rew.shape[0]): + returns += float(rew[i]) + if end_episode[i]: + all_returns.append(returns) + returns = 0 + print('Avg returns:', np.mean(all_returns)) + print('# timeout:', np.sum(timeouts)) + print('# terminals:', np.sum(terminals)) + + +CHECK_FNS = [print_avg_returns, check_reset_nonterminal, check_identical_values, check_num_samples] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets') + args = parser.parse_args() + dirname = args.dirname + for fname in os.listdir(dirname): + if fname.endswith('.hdf5'): + hfile = h5py.File(os.path.join(dirname, fname)) + print('Checking:', fname) + for check_fn in CHECK_FNS: + try: + check_fn(hfile) + except AssertionError as e: + print('Failed test:', check_fn.__name__) + #raise e + diff --git a/d4rl/scripts/check_bullet.py b/d4rl/scripts/check_bullet.py new file mode 100644 index 0000000..d2c63ed --- /dev/null +++ b/d4rl/scripts/check_bullet.py @@ -0,0 +1,61 @@ +""" +A quick script to run a sanity check on all environments.
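+ +Usage: + +python check_bullet.py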
+""" +import gym +import d4rl +import numpy as np + +ENVS = [ + 'bullet-halfcheetah-random-v0', + 'bullet-halfcheetah-medium-v0', + 'bullet-halfcheetah-expert-v0', + 'bullet-halfcheetah-medium-replay-v0', + 'bullet-halfcheetah-medium-expert-v0', + 'bullet-walker2d-random-v0', + 'bullet-walker2d-medium-v0', + 'bullet-walker2d-expert-v0', + 'bullet-walker2d-medium-replay-v0', + 'bullet-walker2d-medium-expert-v0', + 'bullet-hopper-random-v0', + 'bullet-hopper-medium-v0', + 'bullet-hopper-expert-v0', + 'bullet-hopper-medium-replay-v0', + 'bullet-hopper-medium-expert-v0', + 'bullet-ant-random-v0', + 'bullet-ant-medium-v0', + 'bullet-ant-expert-v0', + 'bullet-ant-medium-replay-v0', + 'bullet-ant-medium-expert-v0', + 'bullet-maze2d-open-v0', + 'bullet-maze2d-umaze-v0', + 'bullet-maze2d-medium-v0', + 'bullet-maze2d-large-v0', +] + +if __name__ == '__main__': + for env_name in ENVS: + print('Checking', env_name) + try: + env = gym.make(env_name) + except Exception as e: + print(e) + continue + dset = env.get_dataset() + print('\t Max episode steps:', env._max_episode_steps) + print('\t',dset['observations'].shape, dset['actions'].shape) + assert 'observations' in dset, 'Observations not in dataset' + assert 'actions' in dset, 'Actions not in dataset' + assert 'rewards' in dset, 'Rewards not in dataset' + assert 'terminals' in dset, 'Terminals not in dataset' + N = dset['observations'].shape[0] + print('\t %d samples' % N) + assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) + assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) + assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) + print('\t num terminals: %d' % np.sum(dset['terminals'])) + print('\t avg rew: %f' % np.mean(dset['rewards'])) + + env.reset() + env.step(env.action_space.sample()) + score = env.get_normalized_score(0.0) + diff --git a/d4rl/scripts/check_envs.py b/d4rl/scripts/check_envs.py new file mode 100644 index 0000000..497816e --- /dev/null +++ b/d4rl/scripts/check_envs.py @@ -0,0 +1,92 @@ +""" +A quick script to run a sanity check on all environments. 
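+ +Usage: + +python check_envs.py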
+""" +import gym +import d4rl +import numpy as np + +ENVS = [] + +for agent in ['halfcheetah', 'hopper', 'walker2d', 'ant']: + for dataset in ['random', 'medium', 'expert', 'medium-replay', 'full-replay', 'medium-expert']: + ENVS.append(agent+'-'+dataset+'-v1') + +for agent in ['door', 'pen', 'relocate', 'hammer']: + for dataset in ['expert', 'cloned', 'human']: + ENVS.append(agent+'-'+dataset+'-v1') + +ENVS.extend([ + 'maze2d-open-v0', + 'maze2d-umaze-v1', + 'maze2d-medium-v1', + 'maze2d-large-v1', + 'maze2d-open-dense-v0', + 'maze2d-umaze-dense-v1', + 'maze2d-medium-dense-v1', + 'maze2d-large-dense-v1', + 'minigrid-fourrooms-v0', + 'minigrid-fourrooms-random-v0', + 'pen-human-v0', + 'pen-cloned-v0', + 'pen-expert-v0', + 'hammer-human-v0', + 'hammer-cloned-v0', + 'hammer-expert-v0', + 'relocate-human-v0', + 'relocate-cloned-v0', + 'relocate-expert-v0', + 'door-human-v0', + 'door-cloned-v0', + 'door-expert-v0', + 'antmaze-umaze-v0', + 'antmaze-umaze-diverse-v0', + 'antmaze-medium-play-v0', + 'antmaze-medium-diverse-v0', + 'antmaze-large-play-v0', + 'antmaze-large-diverse-v0', + 'mini-kitchen-microwave-kettle-light-slider-v0', + 'kitchen-microwave-kettle-light-slider-v0', + 'kitchen-microwave-kettle-bottomburner-light-v0', +]) + +if __name__ == '__main__': + for env_name in ENVS: + print('Checking', env_name) + try: + env = gym.make(env_name) + except Exception as e: + print(e) + continue + dset = env.get_dataset() + print('\t Max episode steps:', env._max_episode_steps) + print('\t',dset['observations'].shape, dset['actions'].shape) + assert 'observations' in dset, 'Observations not in dataset' + assert 'actions' in dset, 'Actions not in dataset' + assert 'rewards' in dset, 'Rewards not in dataset' + assert 'terminals' in dset, 'Terminals not in dataset' + N = dset['observations'].shape[0] + print('\t %d samples' % N) + assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) + assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) + assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) + orig_terminals = np.sum(dset['terminals']) + print('\t num terminals: %d' % np.sum(dset['terminals'])) + + env.reset() + env.step(env.action_space.sample()) + score = env.get_normalized_score(0.0) + + dset = d4rl.qlearning_dataset(env, dataset=dset) + assert 'observations' in dset, 'Observations not in dataset' + assert 'next_observations' in dset, 'Observations not in dataset' + assert 'actions' in dset, 'Actions not in dataset' + assert 'rewards' in dset, 'Rewards not in dataset' + assert 'terminals' in dset, 'Terminals not in dataset' + N = dset['observations'].shape[0] + print('\t %d samples' % N) + assert dset['next_observations'].shape[0] == N, 'NextObs number does not match (%d vs %d)' % (dset['actions'].shape[0], N) + assert dset['actions'].shape[0] == N, 'Action number does not match (%d vs %d)' % (dset['actions'].shape[0], N) + assert dset['rewards'].shape[0] == N, 'Reward number does not match (%d vs %d)' % (dset['rewards'].shape[0], N) + assert dset['terminals'].shape[0] == N, 'Terminals number does not match (%d vs %d)' % (dset['terminals'].shape[0], N) + print('\t num terminals: %d' % np.sum(dset['terminals'])) + assert orig_terminals == np.sum(dset['terminals']), 'Qlearining terminals doesnt match original terminals' diff --git a/d4rl/scripts/check_mujoco_datasets.py b/d4rl/scripts/check_mujoco_datasets.py new file 
mode 100644 index 0000000..53fcb79 --- /dev/null +++ b/d4rl/scripts/check_mujoco_datasets.py @@ -0,0 +1,133 @@ +""" +This script runs sanity checks all datasets in a directory. +Assumes all datasets in the directory are generated via mujoco and contain +the qpos/qvel keys. + +Usage: + +python check_mujoco_datasets.py +""" +import numpy as np +import scipy as sp +import scipy.spatial +import h5py +import os +import argparse +import tqdm + + +def check_identical_values(dset): + """ Check that values are not identical """ + check_keys = ['actions', 'observations', 'infos/qpos', 'infos/qvel'] + + for k in check_keys: + values = dset[k][:] + + values_0 = values[0] + values_mid = values[values.shape[0]//2] + values_last = values[-1] + values = np.c_[values_0, values_mid, values_last].T + dists = sp.spatial.distance.pdist(values) + not_same = dists > 0 + assert np.all(not_same) + + +def check_qpos_qvel(dset): + """ Check that qpos/qvel produces correct state""" + import gym + import d4rl + + N = dset['rewards'].shape[0] + qpos = dset['infos/qpos'] + qvel = dset['infos/qvel'] + obs = dset['observations'] + + reverse_env_map = {v.split('/')[-1]: k for (k, v) in d4rl.infos.DATASET_URLS.items()} + env_name = reverse_env_map[dset.filename.split('/')[-1]] + env = gym.make(env_name) + env.reset() + print('checking qpos/qvel') + for t in tqdm.tqdm(range(N)): + env.set_state(qpos[t], qvel[t]) + env_obs = env.env.wrapped_env._get_obs() + error = ((obs[t] - env_obs)**2).sum() + assert error < 1e-8 + +def check_num_samples(dset): + """ Check that all keys have the same # samples """ + check_keys = ['actions', 'observations', 'rewards', 'timeouts', 'terminals', 'infos/qpos', 'infos/qvel'] + + N = None + for k in check_keys: + values = dset[k] + if N is None: + N = values.shape[0] + else: + assert values.shape[0] == N + + +def check_reset_state(dset): + """ Check that resets correspond approximately to the initial state """ + obs = dset['observations'][:] + N = obs.shape[0] + terminals = dset['terminals'][:] + timeouts = dset['timeouts'][:] + end_episode = (timeouts + terminals) > 0 + + # Use the first observation as a reference initial state + reset_state = obs[0] + + # Make sure all reset observations are close to the reference initial state + + # Take up to [:-1] in case last entry in dataset is terminal + end_idxs = np.where(end_episode)[0][:-1] + + diffs = obs[1:] - reset_state + dists = np.linalg.norm(diffs, axis=1) + + min_dist = np.min(dists) + reset_dists = dists[end_idxs] #don't add idx +1 because we took the obs[:1] slice + print('max reset:', np.max(reset_dists)) + print('min reset:', np.min(reset_dists)) + + assert np.all(reset_dists < (min_dist + 1e-2) * 5) + + +def print_avg_returns(dset): + """ Print returns for manual sanity checking. 
""" + rew = dset['rewards'][:] + terminals = dset['terminals'][:] + timeouts = dset['timeouts'][:] + end_episode = (timeouts + terminals) > 0 + + all_returns = [] + returns = 0 + for i in range(rew.shape[0]): + returns += float(rew[i]) + if end_episode[i]: + all_returns.append(returns) + returns = 0 + print('Avg returns:', np.mean(all_returns)) + print('# timeout:', np.sum(timeouts)) + print('# terminals:', np.sum(terminals)) + + +CHECK_FNS = [print_avg_returns, check_qpos_qvel, check_reset_state, check_identical_values, check_num_samples] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('dirname', type=str, help='Directory containing HDF5 datasets') + args = parser.parse_args() + dirname = args.dirname + for fname in os.listdir(dirname): + if fname.endswith('.hdf5'): + hfile = h5py.File(os.path.join(dirname, fname)) + print('Checking:', fname) + for check_fn in CHECK_FNS: + try: + check_fn(hfile) + except AssertionError as e: + print('Failed test:', check_fn.__name__) + raise e + diff --git a/d4rl/scripts/generation/flow_idm.py b/d4rl/scripts/generation/flow_idm.py new file mode 100644 index 0000000..7963576 --- /dev/null +++ b/d4rl/scripts/generation/flow_idm.py @@ -0,0 +1,69 @@ +import numpy as np +import argparse +import gym +import d4rl.flow +from d4rl.utils import dataset_utils + +from flow.controllers import car_following_models + + +def main(): + parser = argparse.ArgumentParser() + #parser.add_argument('--render', action='store_true', help='Render trajectories') + #parser.add_argument('--type', action='store_true', help='Noisy actions') + parser.add_argument('--controller', type=str, default='idm', help='random, idm') + parser.add_argument('--env_name', type=str, default='flow-ring-v0', help='Maze type. 
small or default') + parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') + args = parser.parse_args() + + env = gym.make(args.env_name) + env.reset() + print(env.action_space) + + + if args.controller == 'idm': + uenv = env.unwrapped + veh_ids = uenv.k.vehicle.get_rl_ids() + if hasattr(uenv, 'num_rl'): + num_rl = uenv.num_rl + else: + num_rl = len(veh_ids) + if num_rl == 0: + raise ValueError("No RL vehicles") + controllers = [] + + acc_controller = uenv.k.vehicle.get_acc_controller(uenv.k.vehicle.get_ids()[0]) + car_following_params = acc_controller.car_following_params + #for veh_id in veh_ids: + # controllers.append(car_following_models.IDMController(veh_id, car_following_params=car_following_params)) + + def get_action(s): + actions = np.zeros_like(env.action_space.sample()) + for i, veh_id in enumerate(uenv.k.vehicle.get_rl_ids()): + if i >= actions.shape[0]: + break + actions[i] = car_following_models.IDMController(veh_id, car_following_params=car_following_params).get_accel(env) + return actions + elif args.controller == 'random': + def get_action(s): + return env.action_space.sample() + else: + raise ValueError("Unknown controller type: %s" % str(args.controller)) + + writer = dataset_utils.DatasetWriter() + while len(writer) < args.num_samples: + s = env.reset() + ret = 0 + for _ in range(env._max_episode_steps): + action = get_action(s) + ns , r, done, infos = env.step(action) + ret += r + writer.append_data(s, action, r, done) + s = ns + print(ret) + #env.render() + fname = '%s-%s.hdf5' % (args.env_name, args.controller) + writer.write_dataset(fname, max_size=args.num_samples) + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/generation/generate_ant_maze_datasets.py b/d4rl/scripts/generation/generate_ant_maze_datasets.py new file mode 100644 index 0000000..6c567d5 --- /dev/null +++ b/d4rl/scripts/generation/generate_ant_maze_datasets.py @@ -0,0 +1,174 @@ +import numpy as np +import pickle +import gzip +import h5py +import argparse +from d4rl.locomotion import maze_env, ant, swimmer +from d4rl.locomotion.wrappers import NormalizedBoxEnv +import torch +from PIL import Image +import os + + +def reset_data(): + return {'observations': [], + 'actions': [], + 'terminals': [], + 'timeouts': [], + 'rewards': [], + 'infos/goal': [], + 'infos/qpos': [], + 'infos/qvel': [], + } + +def append_data(data, s, a, r, tgt, done, timeout, env_data): + data['observations'].append(s) + data['actions'].append(a) + data['rewards'].append(r) + data['terminals'].append(done) + data['timeouts'].append(timeout) + data['infos/goal'].append(tgt) + data['infos/qpos'].append(env_data.qpos.ravel().copy()) + data['infos/qvel'].append(env_data.qvel.ravel().copy()) + +def npify(data): + for k in data: + if k in ['terminals', 'timeouts']: + dtype = np.bool_ + else: + dtype = np.float32 + + data[k] = np.array(data[k], dtype=dtype) + +def load_policy(policy_file): + data = torch.load(policy_file) + policy = data['exploration/policy'].to('cpu') + env = data['evaluation/env'] + print("Policy loaded") + return policy, env + +def save_video(save_dir, file_name, frames, episode_id=0): + filename = os.path.join(save_dir, file_name+ '_episode_{}'.format(episode_id)) + if not os.path.exists(filename): + os.makedirs(filename) + num_frames = frames.shape[0] + for i in range(num_frames): + img = Image.fromarray(np.flipud(frames[i]), 'RGB') + img.save(os.path.join(filename, 'frame_{}.png'.format(i))) + +def main(): + parser = argparse.ArgumentParser() + 
parser.add_argument('--noisy', action='store_true', help='Noisy actions') + parser.add_argument('--maze', type=str, default='umaze', help='Maze type. umaze, medium, or large') + parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') + parser.add_argument('--env', type=str, default='Ant', help='Environment type') + parser.add_argument('--policy_file', type=str, default='policy_file', help='file_name') + parser.add_argument('--max_episode_steps', default=1000, type=int) + parser.add_argument('--video', action='store_true') + parser.add_argument('--multi_start', action='store_true') + parser.add_argument('--multigoal', action='store_true') + args = parser.parse_args() + + if args.maze == 'umaze': + maze = maze_env.U_MAZE + elif args.maze == 'medium': + maze = maze_env.BIG_MAZE + elif args.maze == 'large': + maze = maze_env.HARDEST_MAZE + elif args.maze == 'umaze_eval': + maze = maze_env.U_MAZE_EVAL + elif args.maze == 'medium_eval': + maze = maze_env.BIG_MAZE_EVAL + elif args.maze == 'large_eval': + maze = maze_env.HARDEST_MAZE_EVAL + else: + raise NotImplementedError + + if args.env == 'Ant': + env = NormalizedBoxEnv(ant.AntMazeEnv(maze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start)) + elif args.env == 'Swimmer': + env = NormalizedBoxEnv(swimmer.SwimmerMazeEnv(mmaze_map=maze, maze_size_scaling=4.0, non_zero_reset=args.multi_start)) + else: + raise NotImplementedError + + env.set_target() + s = env.reset() + act = env.action_space.sample() + done = False + + # Load the policy + policy, train_env = load_policy(args.policy_file) + + # Define goal reaching policy fn + def _goal_reaching_policy_fn(obs, goal): + goal_x, goal_y = goal + obs_new = obs[2:-2] + goal_tuple = np.array([goal_x, goal_y]) + + # normalize the norm of the relative goals to in-distribution values + goal_tuple = goal_tuple / np.linalg.norm(goal_tuple) * 10.0 + + new_obs = np.concatenate([obs_new, goal_tuple], -1) + return policy.get_action(new_obs)[0], (goal_tuple[0] + obs[0], goal_tuple[1] + obs[1]) + + data = reset_data() + + # create waypoint generating policy integrated with high level controller + data_collection_policy = env.create_navigation_policy( + _goal_reaching_policy_fn, + ) + + if args.video: + frames = [] + + ts = 0 + num_episodes = 0 + for _ in range(args.num_samples): + act, waypoint_goal = data_collection_policy(s) + + if args.noisy: + act = act + np.random.randn(*act.shape)*0.2 + act = np.clip(act, -1.0, 1.0) + + ns, r, done, info = env.step(act) + timeout = False + if ts >= args.max_episode_steps: + timeout = True + #done = True + + append_data(data, s[:-2], act, r, env.target_goal, done, timeout, env.physics.data) + + if len(data['observations']) % 10000 == 0: + print(len(data['observations'])) + + ts += 1 + + if done or timeout: + done = False + ts = 0 + s = env.reset() + env.set_target_goal() + if args.video: + frames = np.array(frames) + save_video('./videos/', args.env + '_navigation', frames, num_episodes) + + num_episodes += 1 + frames = [] + else: + s = ns + + if args.video: + curr_frame = env.physics.render(width=500, height=500, depth=False) + frames.append(curr_frame) + + if args.noisy: + fname = args.env + '_maze_%s_noisy_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal)) + else: + fname = args.env + 'maze_%s_multistart_%s_multigoal_%s.hdf5' % (args.maze, str(args.multi_start), str(args.multigoal)) + dataset = h5py.File(fname, 'w') + npify(data) + for k in data: + dataset.create_dataset(k, 
data=data[k], compression='gzip') + +if __name__ == '__main__': + main() diff --git a/d4rl/scripts/generation/generate_kitchen_datasets.py b/d4rl/scripts/generation/generate_kitchen_datasets.py new file mode 100644 index 0000000..7a0d70d --- /dev/null +++ b/d4rl/scripts/generation/generate_kitchen_datasets.py @@ -0,0 +1,144 @@ +"""Script for generating the datasets for kitchen environments.""" +import d4rl.kitchen +import glob +import gym +import h5py +import numpy as np +import os +import pickle + +np.set_printoptions(precision=2, suppress=True) + +SAVE_DIRECTORY = '~/.offline_rl/datasets' +DEMOS_DIRECTORY = '~/relay-policy-learning/kitchen_demos_multitask' +DEMOS_SUBDIR_PATTERN = '*' +ENVIRONMENTS = ['kitchen_microwave_kettle_light_slider-v0', + 'kitchen_microwave_kettle_bottomburner_light-v0'] +# Uncomment lines below for "mini_kitchen_microwave_kettle_light_slider-v0'". +DEMOS_SUBDIR_PATTERN = '*microwave_kettle_switch_slide' +ENVIRONMENTS = ['mini_kitchen_microwave_kettle_light_slider-v0'] + +OBS_ELEMENT_INDICES = [ + [11, 12], # Bottom burners. + [15, 16], # Top burners. + [17, 18], # Light switch. + [19], # Slide. + [20, 21], # Hinge. + [22], # Microwave. + [23, 24, 25, 26, 27, 28, 29], # Kettle. +] +FLAT_OBS_ELEMENT_INDICES = sum(OBS_ELEMENT_INDICES, []) + +def _relabel_obs_with_goal(obs_array, goal): + obs_array[..., 30:] = goal + return obs_array + + +def _obs_array_to_obs_dict(obs_array, goal=None): + obs_dict = { + 'qp': obs_array[:9], + 'obj_qp': obs_array[9:30], + 'goal': goal, + } + if obs_dict['goal'] is None: + obs_dict['goal'] = obs_array[30:] + return obs_dict + + +def main(): + pattern = os.path.join(DEMOS_DIRECTORY, DEMOS_SUBDIR_PATTERN) + demo_subdirs = sorted(glob.glob(pattern)) + print('Found %d demo subdirs.' % len(demo_subdirs)) + all_demos = {} + for demo_subdir in demo_subdirs: + demo_files = glob.glob(os.path.join(demo_subdir, '*.pkl')) + print('Found %d demos in %s.' % (len(demo_files), demo_subdir)) + demos = [] + for demo_file in demo_files: + with open(demo_file, 'rb') as f: + demo = pickle.load(f) + demos.append(demo) + all_demos[demo_subdir] = demos + + # For debugging... + all_observations = [demo['observations'] for demo in demos] + first_elements = [obs[0, FLAT_OBS_ELEMENT_INDICES] + for obs in all_observations] + last_elements = [obs[-1, FLAT_OBS_ELEMENT_INDICES] + for obs in all_observations] + # End for debugging. + + for env_name in ENVIRONMENTS: + env = gym.make(env_name).unwrapped + env.REMOVE_TASKS_WHEN_COMPLETE = False # This enables a Markovian reward. + all_obs = [] + all_actions = [] + all_rewards = [] + all_terminals = [] + all_infos = [] + print('Relabelling data for %s.' % env_name) + for demo_subdir, demos in all_demos.items(): + print('On demo from %s.' 
% demo_subdir) + demos_obs = [] + demos_actions = [] + demos_rewards = [] + demos_terminals = [] + demos_infos = [] + for idx, demo in enumerate(demos): + env_goal = env._get_task_goal() + rewards = [] + relabelled_obs = _relabel_obs_with_goal(demo['observations'], env_goal) + for obs in relabelled_obs: + reward_dict, score = env._get_reward_n_score( + _obs_array_to_obs_dict(obs)) + + rewards.append(reward_dict['r_total']) + terminate_at = len(rewards) + rewards = rewards[:terminate_at] + demos_obs.append(relabelled_obs[:terminate_at]) + demos_actions.append(demo['actions'][:terminate_at]) + demos_rewards.append(np.array(rewards)) + demos_terminals.append(np.arange(len(rewards)) >= len(rewards) - 1) + demos_infos.append([idx] * len(rewards)) + + all_obs.append(np.concatenate(demos_obs)) + all_actions.append(np.concatenate(demos_actions)) + all_rewards.append(np.concatenate(demos_rewards)) + all_terminals.append(np.concatenate(demos_terminals)) + all_infos.append(np.concatenate(demos_infos)) + + episode_rewards = [np.sum(rewards) for rewards in demos_rewards] + last_rewards = [rewards[-1] for rewards in demos_rewards] + print('Avg episode rewards %f.' % np.mean(episode_rewards)) + print('Avg last step rewards %f.' % np.mean(last_rewards)) + + dataset_obs = np.concatenate(all_obs).astype('float32') + dataset_actions = np.concatenate(all_actions).astype('float32') + dataset_rewards = np.concatenate(all_rewards).astype('float32') + dataset_terminals = np.concatenate(all_terminals).astype('float32') + dataset_infos = np.concatenate(all_infos) + dataset_size = len(dataset_obs) + assert dataset_size == len(dataset_actions) + assert dataset_size == len(dataset_rewards) + assert dataset_size == len(dataset_terminals) + assert dataset_size == len(dataset_infos) + + dataset = { + 'observations': dataset_obs, + 'actions': dataset_actions, + 'rewards': dataset_rewards, + 'terminals': dataset_terminals, + 'infos': dataset_infos, + } + + print('Generated dataset with %d total steps.' % dataset_size) + save_filename = os.path.join(SAVE_DIRECTORY, '%s.hdf5' % env_name) + print('Saving dataset to %s.' 
% save_filename) + h5_dataset = h5py.File(save_filename, 'w') + for key in dataset: + h5_dataset.create_dataset(key, data=dataset[key], compression='gzip') + print('Done.') + + +if __name__ == '__main__': + main() diff --git a/d4rl/scripts/generation/generate_maze2d_bullet_datasets.py b/d4rl/scripts/generation/generate_maze2d_bullet_datasets.py new file mode 100644 index 0000000..d55954a --- /dev/null +++ b/d4rl/scripts/generation/generate_maze2d_bullet_datasets.py @@ -0,0 +1,115 @@ +import gym +import logging +from d4rl.pointmaze import waypoint_controller +from d4rl.pointmaze_bullet import bullet_maze +from d4rl.pointmaze import maze_model +import numpy as np +import pickle +import gzip +import h5py +import argparse +import time + + +def reset_data(): + return {'observations': [], + 'actions': [], + 'terminals': [], + 'timeouts': [], + 'rewards': [], + 'infos/goal': [], + 'infos/goal_reached': [], + 'infos/goal_timeout': [], + 'infos/qpos': [], + 'infos/qvel': [], + } + +def append_data(data, s, a, tgt, done, timeout, robot): + data['observations'].append(s) + data['actions'].append(a) + data['rewards'].append(0.0) + data['terminals'].append(False) + data['timeouts'].append(False) + data['infos/goal'].append(tgt) + data['infos/goal_reached'].append(done) + data['infos/goal_timeout'].append(timeout) + data['infos/qpos'].append(robot.qpos.copy()) + data['infos/qvel'].append(robot.qvel.copy()) + +def npify(data): + for k in data: + if k == 'terminals' or k == 'timeouts': + dtype = np.bool_ + else: + dtype = np.float32 + + data[k] = np.array(data[k], dtype=dtype) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--render', action='store_true', help='Render trajectories') + parser.add_argument('--noisy', action='store_true', help='Noisy actions') + parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type') + parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') + args = parser.parse_args() + + env = gym.make(args.env_name) + maze = env.str_maze_spec + max_episode_steps = env._max_episode_steps + + # default: p=10, d=-1 + controller = waypoint_controller.WaypointController(maze, p_gain=10.0, d_gain=-2.0) + env = bullet_maze.Maze2DBulletEnv(maze) + if args.render: + env.render('human') + + env.set_target() + s = env.reset() + act = env.action_space.sample() + timeout = False + + data = reset_data() + last_position = s[0:2] + ts = 0 + for _ in range(args.num_samples): + position = s[0:2] + velocity = s[2:4] + + # subtract 1.0 due to offset between tabular maze representation and bullet state + act, done = controller.get_action(position, velocity, env._target) + if args.noisy: + act = act + np.random.randn(*act.shape)*0.5 + + act = np.clip(act, -1.0, 1.0) + if ts >= max_episode_steps: + timeout = True + append_data(data, s, act, env._target, done, timeout, env.robot) + + ns, _, _, _ = env.step(act) + + if len(data['observations']) % 10000 == 0: + print(len(data['observations'])) + + ts += 1 + if done: + env.set_target() + done = False + ts = 0 + else: + last_position = s[0:2] + s = ns + + if args.render: + env.render('human') + + + if args.noisy: + fname = '%s-noisy-bullet.hdf5' % args.env_name + else: + fname = '%s-bullet.hdf5' % args.env_name + dataset = h5py.File(fname, 'w') + npify(data) + for k in data: + dataset.create_dataset(k, data=data[k], compression='gzip') + + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/generation/generate_maze2d_datasets.py b/d4rl/scripts/generation/generate_maze2d_datasets.py new file mode 100644 index
0000000..29d72f5 --- /dev/null +++ b/d4rl/scripts/generation/generate_maze2d_datasets.py @@ -0,0 +1,102 @@ +import gym +import logging +from d4rl.pointmaze import waypoint_controller +from d4rl.pointmaze import maze_model +import numpy as np +import pickle +import gzip +import h5py +import argparse + + +def reset_data(): + return {'observations': [], + 'actions': [], + 'terminals': [], + 'rewards': [], + 'infos/goal': [], + 'infos/qpos': [], + 'infos/qvel': [], + } + +def append_data(data, s, a, tgt, done, env_data): + data['observations'].append(s) + data['actions'].append(a) + data['rewards'].append(0.0) + data['terminals'].append(done) + data['infos/goal'].append(tgt) + data['infos/qpos'].append(env_data.qpos.ravel().copy()) + data['infos/qvel'].append(env_data.qvel.ravel().copy()) + +def npify(data): + for k in data: + if k == 'terminals': + dtype = np.bool_ + else: + dtype = np.float32 + + data[k] = np.array(data[k], dtype=dtype) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--render', action='store_true', help='Render trajectories') + parser.add_argument('--noisy', action='store_true', help='Noisy actions') + parser.add_argument('--env_name', type=str, default='maze2d-umaze-v1', help='Maze type') + parser.add_argument('--num_samples', type=int, default=int(1e6), help='Num samples to collect') + args = parser.parse_args() + + env = gym.make(args.env_name) + maze = env.str_maze_spec + max_episode_steps = env._max_episode_steps + + controller = waypoint_controller.WaypointController(maze) + env = maze_model.MazeEnv(maze) + + env.set_target() + s = env.reset() + act = env.action_space.sample() + done = False + + data = reset_data() + ts = 0 + for _ in range(args.num_samples): + position = s[0:2] + velocity = s[2:4] + act, done = controller.get_action(position, velocity, env._target) + if args.noisy: + act = act + np.random.randn(*act.shape)*0.5 + + act = np.clip(act, -1.0, 1.0) + if ts >= max_episode_steps: + done = True + append_data(data, s, act, env._target, done, env.sim.data) + + ns, _, _, _ = env.step(act) + + if len(data['observations']) % 10000 == 0: + print(len(data['observations'])) + + ts += 1 + if done: + env.set_target() + done = False + ts = 0 + else: + s = ns + + if args.render: + env.render() + + + if args.noisy: + fname = '%s-noisy.hdf5' % args.env_name + else: + fname = '%s.hdf5' % args.env_name + dataset = h5py.File(fname, 'w') + npify(data) + for k in data: + dataset.create_dataset(k, data=data[k], compression='gzip') + + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/generation/generate_minigrid_fourroom_data.py b/d4rl/scripts/generation/generate_minigrid_fourroom_data.py new file mode 100644 index 0000000..5fc874b --- /dev/null +++ b/d4rl/scripts/generation/generate_minigrid_fourroom_data.py @@ -0,0 +1,93 @@ +import logging +from offline_rl.gym_minigrid import fourroom_controller +from offline_rl.gym_minigrid.envs import fourrooms +import numpy as np +import pickle +import gzip +import h5py +import argparse + + +def reset_data(): + return {'observations': [], + 'actions': [], + 'terminals': [], + 'rewards': [], + 'infos/goal': [], + 'infos/pos': [], + 'infos/orientation': [], + } + +def append_data(data, s, a, tgt, done, pos, ori): + data['observations'].append(s) + data['actions'].append(a) + data['rewards'].append(0.0) + data['terminals'].append(done) + data['infos/goal'].append(tgt) + data['infos/pos'].append(pos) + data['infos/orientation'].append(ori) + +def npify(data): + for k in data: + if k == 'terminals': + 
dtype = np.bool_ + else: + dtype = np.float32 + + data[k] = np.array(data[k], dtype=dtype) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--render', action='store_true', help='Render trajectories') + parser.add_argument('--random', action='store_true', help='Noisy actions') + parser.add_argument('--num_samples', type=int, default=int(1e5), help='Num samples to collect') + args = parser.parse_args() + + controller = fourroom_controller.FourRoomController() + env = fourrooms.FourRoomsEnv() + + controller.set_target(controller.sample_target()) + s = env.reset() + act = env.action_space.sample() + done = False + + data = reset_data() + ts = 0 + for _ in range(args.num_samples): + if args.render: + env.render() + + if args.random: + act = env.action_space.sample() + else: + act, done = controller.get_action(env.agent_pos, env.agent_dir) + + if ts >= 50: + done = True + append_data(data, s['image'], act, controller.target, done, env.agent_pos, env.agent_dir) + + ns, _, _, _ = env.step(act) + + if len(data['observations']) % 10000 == 0: + print(len(data['observations'])) + + ts += 1 + if done: + controller.set_target(controller.sample_target()) + done = False + ts = 0 + else: + s = ns + + if args.random: + fname = 'minigrid4rooms_random.hdf5' + else: + fname = 'minigrid4rooms.hdf5' + dataset = h5py.File(fname, 'w') + npify(data) + for k in data: + dataset.create_dataset(k, data=data[k], compression='gzip') + + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/generation/hand_dapg_combined.py b/d4rl/scripts/generation/hand_dapg_combined.py new file mode 100644 index 0000000..55ba189 --- /dev/null +++ b/d4rl/scripts/generation/hand_dapg_combined.py @@ -0,0 +1,70 @@ +import gym +import d4rl +import argparse +import os +import numpy as np +import h5py + +def get_keys(h5file): + keys = [] + def visitor(name, item): + if isinstance(item, h5py.Dataset): + keys.append(name) + h5file.visititems(visitor) + return keys + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--env_name', type=str, default='pen', help='Env name') + parser.add_argument('--bc', type=str, help='BC hdf5 dataset') + parser.add_argument('--human', type=str, help='Human demos hdf5 dataset') + args = parser.parse_args() + + env = gym.make('%s-v0' % args.env_name) + human_dataset = h5py.File(args.human, 'r') + bc_dataset = h5py.File(args.bc, 'r') + N = env._max_episode_steps * 5000 + + # search for nearest terminal after the halfway mark + halfN = N // 2 + terms = bc_dataset['terminals'][:] + tos = bc_dataset['timeouts'][:] + last_term = 0 + for i in range(halfN, N): + if terms[i] or tos[i]: + last_term = i + break + halfN = last_term + 1 + + remaining_N = N - halfN + + aug_dataset = h5py.File('%s-cloned-v1.hdf5' % args.env_name, 'w') + for k in get_keys(bc_dataset): + if 'metadata' not in k: + human_data = human_dataset[k][:] + bc_data = bc_dataset[k][:halfN] + print(k, human_data.shape, bc_data.shape) + N_tile = int(halfN / human_data.shape[0]) + 1 + if len(human_data.shape) == 1: + human_data = np.tile(human_data, [N_tile])[:remaining_N] + elif len(human_data.shape) == 2: + human_data = np.tile(human_data, [N_tile, 1])[:remaining_N] + else: + raise NotImplementedError() + + # clone demo_data + aug_data = np.concatenate([bc_data, human_data], axis=0) + assert aug_data.shape[1:] == bc_data.shape[1:] + assert aug_data.shape[1:] == human_data.shape[1:] + + print('\t',human_data.shape, bc_data.shape, '->',aug_data.shape) + 
aug_dataset.create_dataset(k, data=aug_data, compression='gzip') + else: + shape = bc_dataset[k].shape + print('metadata:', k, shape) + if len(shape) == 0: + aug_dataset[k] = bc_dataset[k][()] + else: + aug_dataset[k] = bc_dataset[k][:] + diff --git a/d4rl/scripts/generation/hand_dapg_demos.py b/d4rl/scripts/generation/hand_dapg_demos.py new file mode 100644 index 0000000..1c2b092 --- /dev/null +++ b/d4rl/scripts/generation/hand_dapg_demos.py @@ -0,0 +1,101 @@ +import d4rl +import click +import os +import gym +import numpy as np +import pickle +import h5py +import collections +from mjrl.utils.gym_env import GymEnv + +DESC = ''' +Helper script to visualize demonstrations.\n +USAGE:\n + Visualizes demonstrations on the env\n + $ python utils/visualize_demos --env_name relocate-v0\n +''' + +# MAIN ========================================================= +@click.command(help=DESC) +@click.option('--env_name', type=str, help='environment to load', default='door-v0') +def main(env_name): + if env_name is "": + print("Unknown env.") + return + demos = pickle.load(open('./demonstrations/'+env_name+'_demos.pickle', 'rb')) + # render demonstrations + demo_playback(env_name, demos, clip=True) + +def demo_playback(env_name, demo_paths, clip=False): + e = gym.make(env_name) + e.reset() + + obs_ = [] + act_ = [] + rew_ = [] + term_ = [] + timeout_ = [] + info_qpos_ = [] + info_qvel_ = [] + info_env_state_ = collections.defaultdict(list) + + for i, path in enumerate(demo_paths): + e.set_env_state(path['init_state_dict']) + actions = path['actions'] + returns = 0 + for t in range(actions.shape[0]): + obs_.append(e.get_obs()) + info_qpos_.append(e.env.data.qpos.ravel().copy()) + info_qvel_.append(e.env.data.qvel.ravel().copy()) + [info_env_state_[k].append(v) for k,v in e.get_env_state().items()] + commanded_action = actions[t] + if clip: + commanded_action = np.clip(commanded_action, -1.0, 1.0) + act_.append(commanded_action) + + _, rew, _, info = e.step(commanded_action) + returns += rew + + rew_.append(rew) + + done = False + timeout = False + if t == (actions.shape[0]-1): + timeout = True + #if t == (e._max_episode_steps-1): + # timeout = True + # done = False + + term_.append(done) + timeout_.append(timeout) + + #e.env.mj_render() # this is much faster + #e.render() + print(i, returns, returns/float(actions.shape[0])) + + # write out hdf5 file + obs_ = np.array(obs_).astype(np.float32) + act_ = np.array(act_).astype(np.float32) + rew_ = np.array(rew_).astype(np.float32) + term_ = np.array(term_).astype(np.bool_) + timeout_ = np.array(timeout_).astype(np.bool_) + info_qpos_ = np.array(info_qpos_).astype(np.float32) + info_qvel_ = np.array(info_qvel_).astype(np.float32) + + if clip: + dataset = h5py.File('%s_demos_clipped.hdf5' % env_name, 'w') + else: + dataset = h5py.File('%s_demos.hdf5' % env_name, 'w') + #dataset.create_dataset('observations', obs_.shape, dtype='f4') + dataset.create_dataset('observations', data=obs_, compression='gzip') + dataset.create_dataset('actions', data=act_, compression='gzip') + dataset.create_dataset('rewards', data=rew_, compression='gzip') + dataset.create_dataset('terminals', data=term_, compression='gzip') + dataset.create_dataset('timeouts', data=timeout_, compression='gzip') + #dataset['infos/qpos'] = info_qpos_ + #dataset['infos/qvel'] = info_qvel_ + for k in info_env_state_: + dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip') + +if __name__ == '__main__': + main() diff --git 
a/d4rl/scripts/generation/hand_dapg_jax.py b/d4rl/scripts/generation/hand_dapg_jax.py new file mode 100644 index 0000000..9b7ac32 --- /dev/null +++ b/d4rl/scripts/generation/hand_dapg_jax.py @@ -0,0 +1,145 @@ +import d4rl +import click +import h5py +import os +import gym +import numpy as np +import pickle +import gzip +import collections +from mjrl.utils.gym_env import GymEnv + +DESC = ''' +Helper script to visualize policy (in mjrl format).\n +USAGE:\n + Visualizes policy on the env\n + $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n +''' + +# MAIN ========================================================= +@click.command(help=DESC) +@click.option('--env_name', type=str, help='environment to load', required= True) +@click.option('--snapshot_file', type=str, help='absolute path of the policy file', required=True) +@click.option('--num_trajs', type=int, help='Num trajectories', default=5000) +@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation') +def main(env_name, snapshot_file, mode, num_trajs, clip=True): + e = GymEnv(env_name) + pi = pickle.load(gzip.open(snapshot_file, 'rb')) + import pdb; pdb.set_trace() + pass + # render policy + #pol_playback(env_name, pi, num_trajs, clip=clip) + + +def extract_params(policy): + + out_dict = { + 'fc0/weight': _fc0w, + 'fc0/bias': _fc0b, + 'fc1/weight': params[2].data.numpy(), + 'fc1/bias': params[3].data.numpy(), + 'last_fc/weight': _fclw, + 'last_fc/bias': _fclb, + 'last_fc_log_std/weight': _fclw, + 'last_fc_log_std/bias': _fclb, + } + return out_dict + + +def pol_playback(env_name, pi, num_trajs=100, clip=True): + e = gym.make(env_name) + e.reset() + + obs_ = [] + act_ = [] + rew_ = [] + term_ = [] + timeout_ = [] + info_qpos_ = [] + info_qvel_ = [] + info_mean_ = [] + info_logstd_ = [] + info_env_state_ = collections.defaultdict(list) + + ravg = [] + + for n in range(num_trajs): + e.reset() + returns = 0 + for t in range(e._max_episode_steps): + obs = e.get_obs() + obs_.append(obs) + info_qpos_.append(e.env.data.qpos.ravel().copy()) + info_qvel_.append(e.env.data.qvel.ravel().copy()) + [info_env_state_[k].append(v) for k,v in e.get_env_state().items()] + action, infos = pi.get_action(obs) + action = pi.get_action(obs)[0] # eval + + if clip: + action = np.clip(action, -1, 1) + + act_.append(action) + info_mean_.append(infos['mean']) + info_logstd_.append(infos['log_std']) + + _, rew, done, info = e.step(action) + returns += rew + rew_.append(rew) + + if t == (e._max_episode_steps-1): + timeout = True + done = False + else: + timeout = False + term_.append(done) + timeout_.append(timeout) + + if done or timeout: + e.reset() + break + + #e.env.mj_render() # this is much faster + # e.render() + ravg.append(returns) + print(n, returns, t) + + # write out hdf5 file + obs_ = np.array(obs_).astype(np.float32) + act_ = np.array(act_).astype(np.float32) + rew_ = np.array(rew_).astype(np.float32) + term_ = np.array(term_).astype(np.bool_) + timeout_ = np.array(timeout_).astype(np.bool_) + info_qpos_ = np.array(info_qpos_).astype(np.float32) + info_qvel_ = np.array(info_qvel_).astype(np.float32) + info_mean_ = np.array(info_mean_).astype(np.float32) + info_logstd_ = np.array(info_logstd_).astype(np.float32) + + if clip: + dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w') + else: + dataset = h5py.File('%s_expert.hdf5' % env_name, 'w') + + #dataset.create_dataset('observations', obs_.shape, dtype='f4') + 
dataset.create_dataset('observations', data=obs_, compression='gzip') + dataset.create_dataset('actions', data=act_, compression='gzip') + dataset.create_dataset('rewards', data=rew_, compression='gzip') + dataset.create_dataset('terminals', data=term_, compression='gzip') + dataset.create_dataset('timeouts', data=timeout_, compression='gzip') + #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip') + #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip') + dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip') + dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip') + for k in info_env_state_: + dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip') + + # write metadata + policy_params = extract_params(pi) + dataset['metadata/algorithm'] = np.string_('DAPG') + dataset['metadata/policy/nonlinearity'] = np.string_('tanh') + dataset['metadata/policy/output_distribution'] = np.string_('gaussian') + for k, v in policy_params.items(): + dataset['metadata/policy/'+k] = v + +if __name__ == '__main__': + main() + diff --git a/d4rl/scripts/generation/hand_dapg_policies.py b/d4rl/scripts/generation/hand_dapg_policies.py new file mode 100644 index 0000000..df4d702 --- /dev/null +++ b/d4rl/scripts/generation/hand_dapg_policies.py @@ -0,0 +1,166 @@ +import d4rl +import click +import h5py +import os +import gym +import numpy as np +import pickle +import collections +from mjrl.utils.gym_env import GymEnv + +DESC = ''' +Helper script to visualize policy (in mjrl format).\n +USAGE:\n + Visualizes policy on the env\n + $ python utils/visualize_policy --env_name relocate-v0 --policy policies/relocate-v0.pickle --mode evaluation\n +''' + +# MAIN ========================================================= +@click.command(help=DESC) +@click.option('--env_name', type=str, help='environment to load', required= True) +#@click.option('--policy', type=str, help='absolute path of the policy file', required=True) +@click.option('--num_trajs', type=int, help='Num trajectories', default=5000) +@click.option('--mode', type=str, help='exploration or evaluation mode for policy', default='evaluation') +def main(env_name, mode, num_trajs, clip=True): + e = GymEnv(env_name) + policy = './policies/'+env_name+'.pickle' + pi = pickle.load(open(policy, 'rb')) + # render policy + pol_playback(env_name, pi, num_trajs, clip=clip) + + +def extract_params(policy): + params = policy.trainable_params + + in_shift = policy.model.in_shift.data.numpy() + in_scale = policy.model.in_scale.data.numpy() + out_shift = policy.model.out_shift.data.numpy() + out_scale = policy.model.out_scale.data.numpy() + + fc0w = params[0].data.numpy() + fc0b = params[1].data.numpy() + + _fc0w = np.dot(fc0w, np.diag(1.0 / in_scale)) + _fc0b = fc0b - np.dot(_fc0w, in_shift) + + assert _fc0w.shape == fc0w.shape + assert _fc0b.shape == fc0b.shape + + fclw = params[4].data.numpy() + fclb = params[5].data.numpy() + + _fclw = np.dot(np.diag(out_scale), fclw) + _fclb = fclb * out_scale + out_shift + + assert _fclw.shape == fclw.shape + assert _fclb.shape == fclb.shape + + out_dict = { + 'fc0/weight': _fc0w, + 'fc0/bias': _fc0b, + 'fc1/weight': params[2].data.numpy(), + 'fc1/bias': params[3].data.numpy(), + 'last_fc/weight': _fclw, + 'last_fc/bias': _fclb, + 'last_fc_log_std/weight': _fclw, + 'last_fc_log_std/bias': _fclb, + } + return out_dict + +def pol_playback(env_name, pi, num_trajs=100, clip=True): + e = 
gym.make(env_name) + e.reset() + + obs_ = [] + act_ = [] + rew_ = [] + term_ = [] + timeout_ = [] + info_qpos_ = [] + info_qvel_ = [] + info_mean_ = [] + info_logstd_ = [] + info_env_state_ = collections.defaultdict(list) + + ravg = [] + + for n in range(num_trajs): + e.reset() + returns = 0 + for t in range(e._max_episode_steps): + obs = e.get_obs() + obs_.append(obs) + info_qpos_.append(e.env.data.qpos.ravel().copy()) + info_qvel_.append(e.env.data.qvel.ravel().copy()) + [info_env_state_[k].append(v) for k,v in e.get_env_state().items()] + action, infos = pi.get_action(obs) + action = pi.get_action(obs)[0] # eval + + if clip: + action = np.clip(action, -1, 1) + + act_.append(action) + info_mean_.append(infos['mean']) + info_logstd_.append(infos['log_std']) + + _, rew, done, info = e.step(action) + returns += rew + rew_.append(rew) + + if t == (e._max_episode_steps-1): + timeout = True + done = False + else: + timeout = False + term_.append(done) + timeout_.append(timeout) + + if done or timeout: + e.reset() + break + + #e.env.mj_render() # this is much faster + # e.render() + ravg.append(returns) + print(n, returns, t) + + # write out hdf5 file + obs_ = np.array(obs_).astype(np.float32) + act_ = np.array(act_).astype(np.float32) + rew_ = np.array(rew_).astype(np.float32) + term_ = np.array(term_).astype(np.bool_) + timeout_ = np.array(timeout_).astype(np.bool_) + info_qpos_ = np.array(info_qpos_).astype(np.float32) + info_qvel_ = np.array(info_qvel_).astype(np.float32) + info_mean_ = np.array(info_mean_).astype(np.float32) + info_logstd_ = np.array(info_logstd_).astype(np.float32) + + if clip: + dataset = h5py.File('%s_expert_clip.hdf5' % env_name, 'w') + else: + dataset = h5py.File('%s_expert.hdf5' % env_name, 'w') + + #dataset.create_dataset('observations', obs_.shape, dtype='f4') + dataset.create_dataset('observations', data=obs_, compression='gzip') + dataset.create_dataset('actions', data=act_, compression='gzip') + dataset.create_dataset('rewards', data=rew_, compression='gzip') + dataset.create_dataset('terminals', data=term_, compression='gzip') + dataset.create_dataset('timeouts', data=timeout_, compression='gzip') + #dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip') + #dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip') + dataset.create_dataset('infos/action_mean', data=info_mean_, compression='gzip') + dataset.create_dataset('infos/action_log_std', data=info_logstd_, compression='gzip') + for k in info_env_state_: + dataset.create_dataset('infos/%s' % k, data=np.array(info_env_state_[k], dtype=np.float32), compression='gzip') + + # write metadata + policy_params = extract_params(pi) + dataset['metadata/algorithm'] = np.string_('DAPG') + dataset['metadata/policy/nonlinearity'] = np.string_('tanh') + dataset['metadata/policy/output_distribution'] = np.string_('gaussian') + for k, v in policy_params.items(): + dataset['metadata/policy/'+k] = v + +if __name__ == '__main__': + main() + diff --git a/d4rl/scripts/generation/hand_dapg_random.py b/d4rl/scripts/generation/hand_dapg_random.py new file mode 100644 index 0000000..c2179f2 --- /dev/null +++ b/d4rl/scripts/generation/hand_dapg_random.py @@ -0,0 +1,96 @@ +import brenvs +import click +import h5py +import os +import gym +import numpy as np +import pickle +from mjrl.utils.gym_env import GymEnv + +DESC = ''' +Helper script to visualize policy (in mjrl format).\n +USAGE:\n + Visualizes policy on the env\n + $ python utils/visualize_policy --env_name relocate-v0 --policy 
policies/relocate-v0.pickle --mode evaluation\n +''' + +# MAIN ========================================================= +@click.command(help=DESC) +@click.option('--env_name', type=str, help='environment to load', required= True) +@click.option('--num_trajs', type=int, help='Num trajectories', default=5000) +def main(env_name, num_trajs): + e = GymEnv(env_name) + # render policy + pol_playback(env_name, num_trajs) + +def pol_playback(env_name, num_trajs=100): + e = GymEnv(env_name) + e.reset() + + obs_ = [] + act_ = [] + rew_ = [] + term_ = [] + timeout_ = [] + info_qpos_ = [] + info_qvel_ = [] + info_env_state_ = [] + + ravg = [] + + for n in range(num_trajs): + e.reset() + returns = 0 + for t in range(e._horizon): + obs = e.get_obs() + obs_.append(obs) + info_qpos_.append(e.env.data.qpos.ravel().copy()) + info_qvel_.append(e.env.data.qvel.ravel().copy()) + info_env_state_.append(e.get_env_state()) + action = e.action_space.sample() + act_.append(action) + + _, rew, done, info = e.step(action) + returns += rew + rew_.append(rew) + + if t == (e._horizon-1): + timeout = True + done = False + else: + timeout = False + + term_.append(done) + timeout_.append(timeout) + + if done or timeout: + e.reset() + + #e.env.mj_render() # this is much faster + # e.render() + ravg.append(returns) + + # write out hdf5 file + obs_ = np.array(obs_).astype(np.float32) + act_ = np.array(act_).astype(np.float32) + rew_ = np.array(rew_).astype(np.float32) + term_ = np.array(term_).astype(np.bool_) + timeout_ = np.array(timeout_).astype(np.bool_) + info_qpos_ = np.array(info_qpos_).astype(np.float32) + info_qvel_ = np.array(info_qvel_).astype(np.float32) + + dataset = h5py.File('%s_random.hdf5' % env_name, 'w') + + #dataset.create_dataset('observations', obs_.shape, dtype='f4') + dataset.create_dataset('observations', data=obs_, compression='gzip') + dataset.create_dataset('actions', data=act_, compression='gzip') + dataset.create_dataset('rewards', data=rew_, compression='gzip') + dataset.create_dataset('terminals', data=term_, compression='gzip') + dataset.create_dataset('timeouts', data=timeout_, compression='gzip') + dataset.create_dataset('infos/qpos', data=info_qpos_, compression='gzip') + dataset.create_dataset('infos/qvel', data=info_qvel_, compression='gzip') + dataset.create_dataset('infos/env_state', data=np.array(info_env_state_, dtype=np.float32), compression='gzip') + +if __name__ == '__main__': + main() + diff --git a/d4rl/scripts/generation/mujoco/collect_data.py b/d4rl/scripts/generation/mujoco/collect_data.py new file mode 100644 index 0000000..ad529e4 --- /dev/null +++ b/d4rl/scripts/generation/mujoco/collect_data.py @@ -0,0 +1,169 @@ +import argparse +import re + +import h5py +import torch +import gym +import d4rl +import numpy as np + +from rlkit.torch import pytorch_util as ptu + +itr_re = re.compile(r'itr_(?P[0-9]+).pkl') + +def load(pklfile): + params = torch.load(pklfile) + return params['trainer/policy'] + +def get_pkl_itr(pklfile): + match = itr_re.search(pklfile) + if match: + return match.group('itr') + raise ValueError(pklfile+" has no iteration number.") + +def get_policy_wts(params): + out_dict = { + 'fc0/weight': params.fcs[0].weight.data.numpy(), + 'fc0/bias': params.fcs[0].bias.data.numpy(), + 'fc1/weight': params.fcs[1].weight.data.numpy(), + 'fc1/bias': params.fcs[1].bias.data.numpy(), + 'last_fc/weight': params.last_fc.weight.data.numpy(), + 'last_fc/bias': params.last_fc.bias.data.numpy(), + 'last_fc_log_std/weight': params.last_fc_log_std.weight.data.numpy(), + 
'last_fc_log_std/bias': params.last_fc_log_std.bias.data.numpy(), + } + return out_dict + +def get_reset_data(): + data = dict( + observations = [], + next_observations = [], + actions = [], + rewards = [], + terminals = [], + timeouts = [], + logprobs = [], + qpos = [], + qvel = [] + ) + return data + +def rollout(policy, env_name, max_path, num_data, random=False): + env = gym.make(env_name) + + data = get_reset_data() + traj_data = get_reset_data() + + _returns = 0 + t = 0 + done = False + s = env.reset() + while len(data['rewards']) < num_data: + + + if random: + a = env.action_space.sample() + logprob = np.log(1.0 / np.prod(env.action_space.high - env.action_space.low)) + else: + torch_s = ptu.from_numpy(np.expand_dims(s, axis=0)) + distr = policy.forward(torch_s) + a = distr.sample() + logprob = distr.log_prob(a) + a = ptu.get_numpy(a).squeeze() + + #mujoco only + qpos, qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy() + + try: + ns, rew, done, infos = env.step(a) + except: + print('lost connection') + env.close() + env = gym.make(env_name) + s = env.reset() + traj_data = get_reset_data() + t = 0 + _returns = 0 + continue + + _returns += rew + + t += 1 + timeout = False + terminal = False + if t == max_path: + timeout = True + elif done: + terminal = True + + + traj_data['observations'].append(s) + traj_data['actions'].append(a) + traj_data['next_observations'].append(ns) + traj_data['rewards'].append(rew) + traj_data['terminals'].append(terminal) + traj_data['timeouts'].append(timeout) + traj_data['logprobs'].append(logprob) + traj_data['qpos'].append(qpos) + traj_data['qvel'].append(qvel) + + s = ns + if terminal or timeout: + print('Finished trajectory. Len=%d, Returns=%f. Progress:%d/%d' % (t, _returns, len(data['rewards']), num_data)) + s = env.reset() + t = 0 + _returns = 0 + for k in data: + data[k].extend(traj_data[k]) + traj_data = get_reset_data() + + new_data = dict( + observations=np.array(data['observations']).astype(np.float32), + actions=np.array(data['actions']).astype(np.float32), + next_observations=np.array(data['next_observations']).astype(np.float32), + rewards=np.array(data['rewards']).astype(np.float32), + terminals=np.array(data['terminals']).astype(np.bool), + timeouts=np.array(data['timeouts']).astype(np.bool) + ) + new_data['infos/action_log_probs'] = np.array(data['logprobs']).astype(np.float32) + new_data['infos/qpos'] = np.array(data['qpos']).astype(np.float32) + new_data['infos/qvel'] = np.array(data['qvel']).astype(np.float32) + + for k in new_data: + new_data[k] = new_data[k][:num_data] + return new_data + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('env', type=str) + parser.add_argument('--pklfile', type=str, default=None) + parser.add_argument('--output_file', type=str, default='output.hdf5') + parser.add_argument('--max_path', type=int, default=1000) + parser.add_argument('--num_data', type=int, default=10000) + parser.add_argument('--random', action='store_true') + parser.add_argument('--seed', type=int, default=0) + args = parser.parse_args() + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + policy = None + if not args.random: + policy = load(args.pklfile) + data = rollout(policy, args.env, max_path=args.max_path, num_data=args.num_data, random=args.random) + + hfile = h5py.File(args.output_file, 'w') + for k in data: + hfile.create_dataset(k, data=data[k], compression='gzip') + + if args.random: + pass + else: + hfile['metadata/algorithm'] = np.string_('SAC') + 
hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0] + hfile['metadata/policy/nonlinearity'] = np.string_('relu') + hfile['metadata/policy/output_distribution'] = np.string_('tanh_gaussian') + for k, v in get_policy_wts(policy).items(): + hfile['metadata/policy/'+k] = v + hfile.close() diff --git a/d4rl/scripts/generation/mujoco/convert_buffer.py b/d4rl/scripts/generation/mujoco/convert_buffer.py new file mode 100644 index 0000000..bbbc545 --- /dev/null +++ b/d4rl/scripts/generation/mujoco/convert_buffer.py @@ -0,0 +1,46 @@ +import argparse +import re + +import h5py +import torch +import numpy as np + +itr_re = re.compile(r'itr_(?P[0-9]+).pkl') + +def load(pklfile): + params = torch.load(pklfile) + env_infos = params['replay_buffer/env_infos'] + results = { + 'observations': params['replay_buffer/observations'], + 'next_observations': params['replay_buffer/next_observations'], + 'actions': params['replay_buffer/actions'], + 'rewards': params['replay_buffer/rewards'], + 'terminals': env_infos['terminal'].squeeze(), + 'timeouts': env_infos['timeout'].squeeze(), + 'infos/action_log_probs': env_infos['action_log_prob'].squeeze(), + } + if 'qpos' in env_infos: + results['infos/qpos'] = env_infos['qpos'] + results['infos/qvel'] = env_infos['qvel'] + return results + +def get_pkl_itr(pklfile): + match = itr_re.search(pklfile) + if match: + return match.group('itr') + raise ValueError(pklfile+" has no iteration number.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('pklfile', type=str) + parser.add_argument('--output_file', type=str, default='output.hdf5') + args = parser.parse_args() + + data = load(args.pklfile) + hfile = h5py.File(args.output_file, 'w') + for k in data: + hfile.create_dataset(k, data=data[k], compression='gzip') + hfile['metadata/algorithm'] = np.string_('SAC') + hfile['metadata/iteration'] = np.array([get_pkl_itr(args.pklfile)], dtype=np.int32)[0] + hfile.close() diff --git a/d4rl/scripts/generation/mujoco/fix_qpos_qvel.py b/d4rl/scripts/generation/mujoco/fix_qpos_qvel.py new file mode 100644 index 0000000..4607d58 --- /dev/null +++ b/d4rl/scripts/generation/mujoco/fix_qpos_qvel.py @@ -0,0 +1,124 @@ +import numpy as np +import argparse +import d4rl +import d4rl.offline_env +import gym +import h5py +import os + +def unwrap_env(env): + return env.env.wrapped_env + +def set_state_qpos(env, qpos, qvel): + env.set_state(qpos, qvel) + +def pad_obs(env, obs, twod=False, scale=0.1): + #TODO: sample val + if twod: + val = env.init_qpos[0:2] + np.random.uniform(size=2, low=-.1, high=.1) + state = np.concatenate([np.ones(2)*val, obs]) + else: + val = env.init_qpos[0:1] + np.random.uniform(size=1, low=-scale, high=scale) + state = np.concatenate([np.ones(1)*val, obs]) + return state + +def set_state_obs(env, obs): + env_name = (str(unwrap_env(env).__class__)) + ant_env = 'Ant' in env_name + hopper_walker_env = 'Hopper' in env_name or 'Walker' in env_name + state = pad_obs(env, obs, twod=ant_env, scale=0.005 if hopper_walker_env else 0.1) + qpos_dim = env.sim.data.qpos.size + if ant_env: + env.set_state(state[:15], state[15:29]) + else: + env.set_state(state[:qpos_dim], state[qpos_dim:]) + + +def resync_state_obs(env, obs): + # Prevents drifting of the obs over time + ant_env = 'Ant' in (str(unwrap_env(env).__class__)) + cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy() + if ant_env: + cur_qpos[2:15] = obs[0:13] + cur_qvel = obs[13:27] + env.set_state(cur_qpos, cur_qvel) + 
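+    # Non-Ant locomotion envs (e.g. Hopper, Walker2d, HalfCheetah) hide only the root
+    # x-position from the observation, so obs maps onto qpos[1:] and the full qvel.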
else: + qpos_dim = env.sim.data.qpos.size + cur_qpos[1:] = obs[0:qpos_dim-1] + cur_qvel = obs[qpos_dim-1:] + env.set_state(cur_qpos, cur_qvel) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('env', type=str) + args = parser.parse_args() + + env = gym.make(args.env) + env.reset() + + fname = unwrap_env(env).dataset_url.split('/')[-1] + prefix, ext = os.path.splitext(fname) + #out_fname = prefix+'_qfix'+ext + out_fname = prefix+ext + + dset = env.get_dataset() + all_qpos = dset['infos/qpos'] + all_qvel = dset['infos/qvel'] + observations = dset['observations'] + actions = dset['actions'] + dones = dset['terminals'] + timeouts = dset['timeouts'] + terminals = dones + timeouts + + start_obs = observations[0] + set_state_obs(env, start_obs) + #set_state_qpos(env, all_qpos[0], all_qvel[0]) + + new_qpos = [] + new_qvel = [] + + for t in range(actions.shape[0]): + cur_qpos, cur_qvel = env.sim.data.qpos.ravel().copy(), env.sim.data.qvel.ravel().copy() + new_qpos.append(cur_qpos) + new_qvel.append(cur_qvel) + + next_obs, reward, done, infos = env.step(actions[t]) + + if t == actions.shape[0]-1: + break + if terminals[t]: + set_state_obs(env, observations[t+1]) + #print(t, 'done') + else: + true_next_obs = observations[t+1] + error = ((true_next_obs - next_obs)**2).sum() + if t % 1000 == 0: + print(t, error) + + # prevent drifting over time + resync_state_obs(env, observations[t+1]) + + dset_filepath = d4rl.offline_env.download_dataset_from_url(unwrap_env(env).dataset_url) + inf = h5py.File(dset_filepath, 'r') + outf = h5py.File(out_fname, 'w') + + for k in d4rl.offline_env.get_keys(inf): + print('writing', k) + if 'qpos' in k: + outf.create_dataset(k, data=np.array(new_qpos), compression='gzip') + elif 'qvel' in k: + outf.create_dataset(k, data=np.array(new_qvel), compression='gzip') + else: + try: + if 'reward' in k: + outf.create_dataset(k, data=inf[k][:].squeeze().astype(np.float32), compression='gzip') + else: + if 'terminals' in k or 'timeouts' in k: + outf.create_dataset(k, data=inf[k][:].astype(np.bool), compression='gzip') + else: + outf.create_dataset(k, data=inf[k][:].astype(np.float32), compression='gzip') + except Exception as e: + print(e) + outf.create_dataset(k, data=inf[k]) + outf.close() diff --git a/d4rl/scripts/generation/mujoco/stitch_dataset.py b/d4rl/scripts/generation/mujoco/stitch_dataset.py new file mode 100644 index 0000000..ad8146e --- /dev/null +++ b/d4rl/scripts/generation/mujoco/stitch_dataset.py @@ -0,0 +1,37 @@ +import argparse +import h5py +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('file1', type=str, default=None) + parser.add_argument('file2', type=str, default=None) + parser.add_argument('--output_file', type=str, default='output.hdf5') + parser.add_argument('--maxlen', type=int, default=2000000) + args = parser.parse_args() + + hfile1 = h5py.File(args.file1, 'r') + hfile2 = h5py.File(args.file2, 'r') + outf = h5py.File(args.output_file, 'w') + + keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals', 'timeouts', 'infos/action_log_probs', 'infos/qpos', 'infos/qvel'] + # be careful with trajectories not ending at the end of a file! 
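+    # To avoid splicing a partial trajectory from file1 into file2, file1 is truncated
+    # at the last index where a terminal or timeout flag is set before concatenation.
+    # Example invocation (hypothetical file names):
+    #   python stitch_dataset.py run1.hdf5 run2.hdf5 --output_file combined.hdf5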
+ + # find end of last traj + terms = hfile1['terminals'][:] + tos = hfile1['timeouts'][:] + last_term = 0 + for i in range(terms.shape[0]-1, -1, -1): + if terms[i] or tos[i]: + last_term = i + break + N = last_term + 1 + + for k in keys: + d1 = hfile1[k][:N] + d2 = hfile2[k][:] + combined = np.concatenate([d1,d2],axis=0)[:args.maxlen] + print(k, combined.shape) + outf.create_dataset(k, data=combined, compression='gzip') + + outf.close() diff --git a/d4rl/scripts/generation/relabel_antmaze_rewards.py b/d4rl/scripts/generation/relabel_antmaze_rewards.py new file mode 100644 index 0000000..c86c279 --- /dev/null +++ b/d4rl/scripts/generation/relabel_antmaze_rewards.py @@ -0,0 +1,49 @@ +import d4rl.locomotion +from d4rl.offline_env import get_keys +import os +import argparse +import numpy as np +import gym +import h5py + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', default='antmaze-umaze-v0', help='') + parser.add_argument('--relabel_type', default='sparse', help='') + parser.add_argument('--filename', type=str) + args = parser.parse_args() + + env = gym.make(args.env_name) + target_goal = env.target_goal + print ('Target Goal: ', target_goal) + + rdataset = h5py.File(args.filename, 'r') + fpath, ext = os.path.splitext(args.filename) + wdataset = h5py.File(fpath + '_' + args.relabel_type + ext, 'w') + + all_obs = rdataset['observations'][:] + + if args.relabel_type == 'dense': + """reward at the next state = dist(s', g)""" + _rew = np.exp(-np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1)) + elif args.relabel_type == 'sparse': + _rew = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32) + else: + _rew = rdataset['rewards'][:] + + # Also add terminals here + _terminals = (np.linalg.norm(all_obs[1:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32) + _terminals = np.concatenate([_terminals, np.array([0])], 0) + _rew = np.concatenate([_rew, np.array([0])], 0) + print ('Sum of rewards: ', _rew.sum()) + + for k in get_keys(rdataset): + print(k) + if k == 'rewards': + wdataset.create_dataset(k, data=_rew, compression='gzip') + elif k == 'terminals': + wdataset.create_dataset(k, data=_terminals, compression='gzip') + else: + wdataset.create_dataset(k, data=rdataset[k], compression='gzip') + diff --git a/d4rl/scripts/generation/relabel_maze2d_rewards.py b/d4rl/scripts/generation/relabel_maze2d_rewards.py new file mode 100644 index 0000000..76676ec --- /dev/null +++ b/d4rl/scripts/generation/relabel_maze2d_rewards.py @@ -0,0 +1,49 @@ +from d4rl.pointmaze import MazeEnv, maze_model +from d4rl.offline_env import get_keys +import os +import argparse +import numpy as np +import h5py + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='SAC-BEAR') + parser.add_argument('--maze', default='umaze', help='') + parser.add_argument('--relabel_type', default='dense', help='') + parser.add_argument('--filename', type=str) + args = parser.parse_args() + + + if args.maze == 'umaze': + maze = maze_model.U_MAZE + elif args.maze == 'open': + maze = maze_model.OPEN + elif args.maze == 'medium': + maze = maze_model.MEDIUM_MAZE + else: + maze = maze_model.LARGE_MAZE + env = MazeEnv(maze, reset_target=False, reward_type='sparse') + target_goal = env._target + + rdataset = h5py.File(args.filename, 'r') + fpath, ext = os.path.splitext(args.filename) + wdataset = h5py.File(fpath+'-'+args.relabel_type+ext, 'w') + + all_obs = rdataset['observations'] + if args.relabel_type == 'dense': + _rew = 
np.exp(-np.linalg.norm(all_obs[:,:2] - target_goal, axis=1)) + elif args.relabel_type == 'sparse': + _rew = (np.linalg.norm(all_obs[:,:2] - target_goal, axis=1) <= 0.5).astype(np.float32) + else: + _rew = rdataset['rewards'].value + + for k in get_keys(rdataset): + print(k) + if k == 'rewards': + wdataset.create_dataset(k, data=_rew, compression='gzip') + else: + if k.startswith('metadata'): + wdataset[k] = rdataset[k][()] + else: + wdataset.create_dataset(k, data=rdataset[k], compression='gzip') + diff --git a/d4rl/scripts/ope_rollout.py b/d4rl/scripts/ope_rollout.py new file mode 100644 index 0000000..565b8ea --- /dev/null +++ b/d4rl/scripts/ope_rollout.py @@ -0,0 +1,36 @@ +""" +This script runs rollouts on the OPE policies +using the ONNX runtime and averages the returns. +""" +import d4rl +import gym +import sys +import onnx +import onnxruntime as ort +import numpy as np +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('policy', type=str, help='ONNX policy file. i.e. cheetah.sampler.onnx') +parser.add_argument('env_name', type=str, help='Env name') +parser.add_argument('--num_rollouts', type=int, default=10, help='Number of rollouts to run.') +args = parser.parse_args() + +env = gym.make(args.env_name) + +policy = ort.InferenceSession(args.policy) + +all_returns = [] +for _ in range(args.num_rollouts): + s = env.reset() + returns = 0 + for t in range(env._max_episode_steps): + obs_input = np.expand_dims(s, axis=0).astype(np.float32) + noise_input = np.random.randn(1, env.action_space.shape[0]).astype(np.float32) + action, _, _ = policy.run(None, {'observations': obs_input, 'noise': noise_input}) + s, r, d, _ = env.step(action) + returns += r + print(returns, end='\r') + all_returns.append(returns) +print(args.env_name, ':', np.mean(all_returns)) + diff --git a/d4rl/scripts/reference_scores/adroit_expert.py b/d4rl/scripts/reference_scores/adroit_expert.py new file mode 100644 index 0000000..61c1520 --- /dev/null +++ b/d4rl/scripts/reference_scores/adroit_expert.py @@ -0,0 +1,48 @@ +""" +Instructions: + +1) Download the expert policies from https://github.com/aravindr93/hand_dapg +2) Place the policies from dapg_policies in the current directory +3) Run this script passing in the appropriate env_name +""" +import d4rl +import argparse +import os +import gym +import numpy as np +import pickle +from mjrl.utils.gym_env import GymEnv + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', default='', help='Environment Name') + parser.add_argument('--num_episodes', type=int, default=100) + args = parser.parse_args() + + policy = './policies/'+args.env_name+'.pickle' + pi = pickle.load(open(policy, 'rb')) + e = gym.make(args.env_name) + e.seed(0) + e.reset() + + ravg = [] + for n in range(args.num_episodes): + e.reset() + returns = 0 + for t in range(e._max_episode_steps): + obs = e.get_obs() + action, infos = pi.get_action(obs) + action = pi.get_action(obs)[0] # eval + _, rew, done, info = e.step(action) + returns += rew + if done: + break + # e.env.mj_render() # this is much faster + # e.render() + ravg.append(returns) + print(args.env_name, 'returns', np.mean(ravg)) + + +if __name__ == '__main__': + main() + diff --git a/d4rl/scripts/reference_scores/carla_lane_controller.py b/d4rl/scripts/reference_scores/carla_lane_controller.py new file mode 100644 index 0000000..8eb0350 --- /dev/null +++ b/d4rl/scripts/reference_scores/carla_lane_controller.py @@ -0,0 +1,37 @@ +import d4rl +import gym +from d4rl.carla import
data_collection_agent_lane +import numpy as np +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', type=str, default='carla-lane-v0', help='Maze type. small or default') + parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect') + args = parser.parse_args() + + env = gym.make(args.env_name) + env.seed(0) + np.random.seed(0) + + ravg = [] + for i in range(args.num_episodes): + s = env.reset() + controller = data_collection_agent_lane.RoamingAgent(env) + returns = 0 + for t in range(env._max_episode_steps): + act = controller.compute_action() + + s, rew, done, _ = env.step(act) + returns += rew + if done: + break + ravg.append(returns) + print(i, returns, ' mean:', np.mean(ravg)) + print(args.env_name, 'returns', np.mean(ravg)) + + +if __name__ == "__main__": + main() + diff --git a/d4rl/scripts/reference_scores/generate_ref_min_score.py b/d4rl/scripts/reference_scores/generate_ref_min_score.py new file mode 100644 index 0000000..f0d9ea5 --- /dev/null +++ b/d4rl/scripts/reference_scores/generate_ref_min_score.py @@ -0,0 +1,38 @@ +""" +Generate "minimum" reference scores by averaging the score for a random +policy over 100 episodes. +""" +import d4rl +import argparse +import gym +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', default='', help='Environment Name') + parser.add_argument('--num_episodes', type=int, default=100) + args = parser.parse_args() + + env = gym.make(args.env_name) + env.seed(0) + try: + env.action_space.seed(0) + except: + pass + + ravg = [] + for n in range(args.num_episodes): + env.reset() + returns = 0 + for t in range(env._max_episode_steps): + action = env.action_space.sample() + _, rew, done, info = env.step(action) + returns += rew + if done: + break + ravg.append(returns) + print('%s Average returns (%d ep): %f' % (args.env_name, args.num_episodes, np.mean(ravg))) + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/reference_scores/generate_ref_min_score.sh b/d4rl/scripts/reference_scores/generate_ref_min_score.sh new file mode 100755 index 0000000..bb8f07b --- /dev/null +++ b/d4rl/scripts/reference_scores/generate_ref_min_score.sh @@ -0,0 +1,5 @@ +for e in $(cat scripts/reference_scores/envs.txt) +do + python scripts/reference_scores/generate_ref_min_score.py --env_name=$e +done + diff --git a/d4rl/scripts/reference_scores/maze2d_bullet_controller.py b/d4rl/scripts/reference_scores/maze2d_bullet_controller.py new file mode 100644 index 0000000..2d8f03e --- /dev/null +++ b/d4rl/scripts/reference_scores/maze2d_bullet_controller.py @@ -0,0 +1,49 @@ +import d4rl +import gym +from d4rl.pointmaze import waypoint_controller +from d4rl.pointmaze import maze_model +import numpy as np +import argparse +import time + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. 
small or default') + parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect') + parser.add_argument('--render', action='store_true') + args = parser.parse_args() + + env = gym.make(args.env_name) + if args.render: + env.render('human') + env.seed(0) + np.random.seed(0) + d_gain = -2.0 + p_gain = 10.0 + controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain) + print('max steps:', env._max_episode_steps) + + ravg = [] + for _ in range(args.num_episodes): + controller = waypoint_controller.WaypointController(env.env.str_maze_spec, p_gain=p_gain, d_gain=d_gain) + s = env.reset() + returns = 0 + for t in range(env._max_episode_steps): + position = s[0:2] + velocity = s[2:4] + act, done = controller.get_action(position, velocity, np.array(env.env.get_target())) + #print(position-1, controller.current_waypoint(), np.array(env.env.get_target()) - 1) + #print('\t', act) + s, rew, _, _ = env.step(act) + if args.render: + time.sleep(0.01) + env.render('human') + returns += rew + print(returns) + ravg.append(returns) + print(args.env_name, 'returns', np.mean(ravg)) + + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/reference_scores/maze2d_controller.py b/d4rl/scripts/reference_scores/maze2d_controller.py new file mode 100644 index 0000000..49f93f5 --- /dev/null +++ b/d4rl/scripts/reference_scores/maze2d_controller.py @@ -0,0 +1,35 @@ +import d4rl +import gym +from d4rl.pointmaze import waypoint_controller +from d4rl.pointmaze import maze_model +import numpy as np +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0', help='Maze type. small or default') + parser.add_argument('--num_episodes', type=int, default=100, help='Num samples to collect') + args = parser.parse_args() + + env = gym.make(args.env_name) + env.seed(0) + np.random.seed(0) + controller = waypoint_controller.WaypointController(env.str_maze_spec) + + ravg = [] + for _ in range(args.num_episodes): + s = env.reset() + returns = 0 + for t in range(env._max_episode_steps): + position = s[0:2] + velocity = s[2:4] + act, done = controller.get_action(position, velocity, env.get_target()) + s, rew, _, _ = env.step(act) + returns += rew + ravg.append(returns) + print(args.env_name, 'returns', np.mean(ravg)) + + +if __name__ == "__main__": + main() diff --git a/d4rl/scripts/reference_scores/minigrid_controller.py b/d4rl/scripts/reference_scores/minigrid_controller.py new file mode 100644 index 0000000..6b115fb --- /dev/null +++ b/d4rl/scripts/reference_scores/minigrid_controller.py @@ -0,0 +1,36 @@ +import logging +from offline_rl.gym_minigrid import fourroom_controller +from offline_rl.gym_minigrid.envs import fourrooms +import numpy as np +import pickle +import gzip +import h5py +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--num_episodes', type=int, default=100, help='Num trajs to collect') + args = parser.parse_args() + + np.random.seed(0) + + env = fourrooms.FourRoomsEnv() + env.seed(0) + controller = fourroom_controller.FourRoomController() + controller.set_target(env.get_target()) + + ravg = [] + for _ in range(args.num_episodes): + s = env.reset() + returns = 0 + for t in range(50): + act, done = controller.get_action(env.agent_pos, env.agent_dir) + ns, rew, _, _ = env.step(act) + returns += rew + ravg.append(returns) + print('returns', np.mean(ravg)) + + +if __name__ == "__main__": + main() diff --git 
a/d4rl/scripts/visualize_dataset.py b/d4rl/scripts/visualize_dataset.py new file mode 100644 index 0000000..b6a2202 --- /dev/null +++ b/d4rl/scripts/visualize_dataset.py @@ -0,0 +1,25 @@ +import argparse +import d4rl +import gym + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--env_name', type=str, default='maze2d-umaze-v0') + args = parser.parse_args() + + env = gym.make(args.env_name) + + dataset = env.get_dataset() + if 'infos/qpos' not in dataset: + raise ValueError('Only MuJoCo-based environments can be visualized') + qpos = dataset['infos/qpos'] + qvel = dataset['infos/qvel'] + rewards = dataset['rewards'] + actions = dataset['actions'] + + env.reset() + env.set_state(qpos[0], qvel[0]) + for t in range(qpos.shape[0]): + env.set_state(qpos[t], qvel[t]) + env.render() diff --git a/d4rl/setup.py b/d4rl/setup.py new file mode 100644 index 0000000..a3d0ab5 --- /dev/null +++ b/d4rl/setup.py @@ -0,0 +1,30 @@ +from distutils.core import setup +from platform import platform + +from setuptools import find_packages + +setup( + name='d4rl', + version='1.1', + install_requires=['gym', + 'numpy', + 'mujoco_py', + 'pybullet', + 'h5py', + 'termcolor', # adept_envs dependency + 'click', # adept_envs dependency + 'dm_control' if 'macOS' in platform() else + 'dm_control @ git+https://github.com/deepmind/dm_control@main#egg=dm_control', + #'mjrl @ git+https://github.com/aravindr93/mjrl@master#egg=mjrl' + ], + packages=find_packages(), + package_data={'d4rl': ['locomotion/assets/*', + 'hand_manipulation_suite/assets/*', + 'hand_manipulation_suite/Adroit/*', + 'hand_manipulation_suite/Adroit/gallery/*', + 'hand_manipulation_suite/Adroit/resources/*', + 'hand_manipulation_suite/Adroit/resources/meshes/*', + 'hand_manipulation_suite/Adroit/resources/textures/*', + ]}, + include_package_data=True, +) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..ead6186 --- /dev/null +++ b/environment.yml @@ -0,0 +1,16 @@ +name: decision-transformer +dependencies: +- conda-forge::cudatoolkit=11.3.1 +- python=3.9.7 +- pytorch::pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0 +- scipy=1.7.3 +- pytorch::torchvision=0.11.3=py39_cu113 +- anaconda +- numpy +- pip +- pip: + - gym==0.18.3 + - mujoco-py + - transformers==4.5.1 + - wandb + - matplotlib diff --git a/gym/conda_env.yml b/gym/conda_env.yml new file mode 100644 index 0000000..ead6186 --- /dev/null +++ b/gym/conda_env.yml @@ -0,0 +1,16 @@ +name: decision-transformer +dependencies: +- conda-forge::cudatoolkit=11.3.1 +- python=3.9.7 +- pytorch::pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0 +- scipy=1.7.3 +- pytorch::torchvision=0.11.3=py39_cu113 +- anaconda +- numpy +- pip +- pip: + - gym==0.18.3 + - mujoco-py + - transformers==4.5.1 + - wandb + - matplotlib diff --git a/gym/data/download_d4rl_datasets.py b/gym/data/download_d4rl_datasets.py new file mode 100644 index 0000000..5751cee --- /dev/null +++ b/gym/data/download_d4rl_datasets.py @@ -0,0 +1,50 @@ +import gym +import numpy as np + +import collections +import pickle + +import d4rl + + +datasets = [] + +for env_name in ['halfcheetah', 'hopper', 'walker2d']: + for dataset_type in ['medium', 'medium-replay', 'expert']: + name = f'{env_name}-{dataset_type}-v2' + env = gym.make(name) + dataset = env.get_dataset() + + N = dataset['rewards'].shape[0] + data_ = collections.defaultdict(list) + + use_timeouts = False + if 'timeouts' in dataset: + use_timeouts = True + + episode_step = 0 + paths = [] + for i in range(N): + done_bool = 
bool(dataset['terminals'][i]) + if use_timeouts: + final_timestep = dataset['timeouts'][i] + else: + final_timestep = (episode_step == 1000-1) + for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']: + data_[k].append(dataset[k][i]) + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + paths.append(episode_data) + data_ = collections.defaultdict(list) + episode_step += 1 + + returns = np.array([np.sum(p['rewards']) for p in paths]) + num_samples = np.sum([p['rewards'].shape[0] for p in paths]) + print(f'Number of samples collected: {num_samples}') + print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') + + with open(f'{name}.pkl', 'wb') as f: + pickle.dump(paths, f) diff --git a/gym/decision_transformer/envs/assets/reacher_2d.xml b/gym/decision_transformer/envs/assets/reacher_2d.xml new file mode 100644 index 0000000..8f988b7 --- /dev/null +++ b/gym/decision_transformer/envs/assets/reacher_2d.xml @@ -0,0 +1,33 @@ + + + + + + + + diff --git a/gym/decision_transformer/envs/reacher_2d.py b/gym/decision_transformer/envs/reacher_2d.py new file mode 100644 index 0000000..202cb41 --- /dev/null +++ b/gym/decision_transformer/envs/reacher_2d.py @@ -0,0 +1,62 @@ +import numpy as np +from gym import utils +from gym.envs.mujoco import mujoco_env + +import os + + +class Reacher2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): + + def __init__(self): + self.fingertip_sid = 0 + self.target_bid = 0 + curr_dir = os.path.dirname(os.path.abspath(__file__)) + mujoco_env.MujocoEnv.__init__(self, curr_dir+'/assets/reacher_2d.xml', 15) + self.fingertip_sid = self.sim.model.site_name2id('fingertip') + self.target_bid = self.sim.model.body_name2id('target') + utils.EzPickle.__init__(self) + + def step(self, action): + action = np.clip(action, -1.0, 1.0) + self.do_simulation(action, self.frame_skip) + tip = self.data.site_xpos[self.fingertip_sid][:2] + tar = self.data.body_xpos[self.target_bid][:2] + dist = np.sum(np.abs(tip - tar)) + reward_dist = 0. 
# - 0.1 * dist + reward_ctrl = 0.0 + reward_bonus = 1.0 if dist < 0.1 else 0.0 + reward = reward_bonus + reward_ctrl + reward_dist + done = False + ob = self._get_obs() + return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, reward_bonus=reward_bonus) + + def _get_obs(self): + theta = self.data.qpos.ravel() + tip = self.data.site_xpos[self.fingertip_sid][:2] + tar = self.data.body_xpos[self.target_bid][:2] + return np.concatenate([ + # self.data.qpos.flat, + np.sin(theta), + np.cos(theta), + self.dt * self.data.qvel.ravel(), + tip, + tar, + tip-tar, + ]) + + def reset_model(self): + # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos + # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + qpos = self.np_random.uniform(low=-2.0, high=2.0, size=self.model.nq) + qvel = self.init_qvel * 0.0 + while True: + self.goal = self.np_random.uniform(low=-1.5, high=1.5, size=2) + if np.linalg.norm(self.goal) <= 1.0 and np.linalg.norm(self.goal) >= 0.5: + break + self.set_state(qpos, qvel) + self.model.body_pos[self.target_bid][:2] = self.goal + self.sim.forward() + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.distance = self.model.stat.extent * 5.0 diff --git a/gym/decision_transformer/evaluation/evaluate_episodes.py b/gym/decision_transformer/evaluation/evaluate_episodes.py new file mode 100644 index 0000000..7698196 --- /dev/null +++ b/gym/decision_transformer/evaluation/evaluate_episodes.py @@ -0,0 +1,153 @@ +import numpy as np +import torch + + +def evaluate_episode( + env, + state_dim, + act_dim, + model, + max_ep_len=1000, + device='cuda', + target_return=None, + mode='normal', + state_mean=0., + state_std=1., +): + + model.eval() + model.to(device=device) + + state_mean = torch.from_numpy(state_mean).to(device=device) + state_std = torch.from_numpy(state_std).to(device=device) + + state = env.reset() + + # we keep all the histories on the device + # note that the latest action and reward will be "padding" + states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) + actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) + rewards = torch.zeros(0, device=device, dtype=torch.float32) + target_return = torch.tensor(target_return, device=device, dtype=torch.float32) + sim_states = [] + + episode_return, episode_length = 0, 0 + for t in range(max_ep_len): + + # add padding + actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) + rewards = torch.cat([rewards, torch.zeros(1, device=device)]) + + action = model.get_action( + (states.to(dtype=torch.float32) - state_mean) / state_std, + actions.to(dtype=torch.float32), + rewards.to(dtype=torch.float32), + target_return=target_return, + ) + actions[-1] = action + action = action.detach().cpu().numpy() + + state, reward, done, _ = env.step(action) + + cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) + states = torch.cat([states, cur_state], dim=0) + rewards[-1] = reward + + episode_return += reward + episode_length += 1 + + if done: + break + + return episode_return, episode_length + + +def evaluate_episode_rtg( + env, + state_dim, + act_dim, + model, + max_ep_len=1000, + scale=1000., + state_mean=0., + state_std=1., + device='cuda', + target_return=None, + mode='normal', + use_means=False, + return_traj=False, + eval_context=None + ): + + model.eval() + model.to(device=device) + + state_mean = 
torch.from_numpy(state_mean).to(device=device) + state_std = torch.from_numpy(state_std).to(device=device) + + state = env.reset() + if mode == 'noise': + state = state + np.random.normal(0, 0.1, size=state.shape) + + # we keep all the histories on the device + # note that the latest action and reward will be "padding" + states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) + actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) + rewards = torch.zeros(0, device=device, dtype=torch.float32) + + ep_return = target_return + target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1) + timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1) + + sim_states = [] + + episode_return, episode_length = 0, 0 + for t in range(max_ep_len): + + + actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0) + rewards = torch.cat([rewards, torch.zeros(1, device=device)]) + action = model.get_action( + (states.to(dtype=torch.float32) - state_mean) / state_std, + actions.to(dtype=torch.float32), + rewards.to(dtype=torch.float32), + target_return.to(dtype=torch.float32), + timesteps.to(dtype=torch.long), + use_means=use_means, + custom_max_length=eval_context + ) + actions[-1] = action + action = action.detach().cpu().numpy() + + state, reward, done, _ = env.step(action) + + cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim) + states = torch.cat([states, cur_state], dim=0) + rewards[-1] = reward + + if mode != 'delayed': + pred_return = target_return[0,-1] - (reward/scale) + else: + pred_return = target_return[0,-1] + target_return = torch.cat( + [target_return, pred_return.reshape(1, 1)], dim=1) + timesteps = torch.cat( + [timesteps, + torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1) + + episode_return += reward + episode_length += 1 + + if done: + break + + if return_traj: + traj = { + 'observations': states[:-1].cpu().detach().numpy(), + 'actions': actions.cpu().detach().numpy(), + 'rewards': rewards.cpu().detach().numpy(), + 'terminals': np.zeros(episode_length, dtype=bool) + } + return episode_return, episode_length, traj + else: + return episode_return, episode_length diff --git a/gym/decision_transformer/models/decision_transformer.py b/gym/decision_transformer/models/decision_transformer.py new file mode 100644 index 0000000..457439c --- /dev/null +++ b/gym/decision_transformer/models/decision_transformer.py @@ -0,0 +1,226 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.distributions import Normal, Independent +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import TanhTransform + +import transformers + +from decision_transformer.models.model import TrajectoryModel +from decision_transformer.models.trajectory_gpt2 import GPT2Model + + +class DecisionTransformer(TrajectoryModel): + + """ + This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...) 
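+    Each modality (return-to-go, state, action) is embedded separately, the tokens are
+    interleaved per timestep, and the action for step t is predicted from the hidden
+    state of the state token s_t; with stochastic=True the action head parameterizes a
+    Gaussian (optionally tanh-squashed) instead of a deterministic output.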
+ """ + + def __init__( + self, + state_dim, + act_dim, + hidden_size, + max_length=None, + max_ep_len=4096, + action_tanh=True, + stochastic=False, + log_std_min=-20, + log_std_max=2, + remove_pos_embs=False, + stochastic_tanh=False, + approximate_entropy_samples=1000, + **kwargs + ): + super().__init__(state_dim, act_dim, max_length=max_length) + + self.hidden_size = hidden_size + config = transformers.GPT2Config( + vocab_size=1, # doesn't matter -- we don't use the vocab + n_embd=hidden_size, + **kwargs + ) + + # note: the only difference between this GPT2Model and the default Huggingface version + # is that the positional embeddings are removed (since we'll add those ourselves) + self.transformer = GPT2Model(config) + + # Settings from stochastic actions + self.stochastic = stochastic + self.log_std_min=log_std_min + self.log_std_max=log_std_max + self.stochastic_tanh=stochastic_tanh + self.approximate_entropy_samples=approximate_entropy_samples + + + + self.remove_pos_embs = remove_pos_embs + if not remove_pos_embs: + self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) + self.embed_return = torch.nn.Linear(1, hidden_size) + self.embed_state = torch.nn.Linear(self.state_dim, hidden_size) + self.embed_action = torch.nn.Linear(self.act_dim, hidden_size) + + self.embed_ln = nn.LayerNorm(hidden_size) + + self.predict_state = torch.nn.Linear(hidden_size, self.state_dim) + + if stochastic: + self.predict_action_mean = nn.Sequential( + nn.Linear(hidden_size, self.act_dim), + ) + self.predict_action_logstd = nn.Sequential( + nn.Linear(hidden_size, self.act_dim), + ) + else: + self.predict_action = nn.Sequential( + *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else [])) + ) + self.predict_return = torch.nn.Linear(hidden_size, 1) + + def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None, target_actions=None, use_means=False): + + batch_size, seq_length = states.shape[0], states.shape[1] + + transition_size = 3 + + if attention_mask is None: + # attention mask for GPT: 1 if can be attended to, 0 if not + attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device) + + # embed each modality with a different head + state_embeddings = self.embed_state(states) + action_embeddings = self.embed_action(actions) + returns_embeddings = self.embed_return(returns_to_go) + + # Optionally can remove, may be better for certain domains if order can be inferred by return seq + if not self.remove_pos_embs: + time_embeddings = self.embed_timestep(timesteps) + + # time embeddings are treated similar to positional embeddings + state_embeddings = state_embeddings + time_embeddings + action_embeddings = action_embeddings + time_embeddings + returns_embeddings = returns_embeddings + time_embeddings + + # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...) 
+ # which works nice in an autoregressive sense since states predict actions + embeddings = (returns_embeddings, state_embeddings, action_embeddings) + stacked_inputs = torch.stack( + embeddings, dim=1 + ).permute(0, 2, 1, 3).reshape(batch_size, transition_size*seq_length, self.hidden_size) + stacked_inputs = self.embed_ln(stacked_inputs) + + # to make the attention mask fit the stacked inputs, have to stack it as well + attention_masks = (attention_mask, attention_mask, attention_mask) + + stacked_attention_mask = torch.stack( + attention_masks, dim=1 + ).permute(0, 2, 1).reshape(batch_size, transition_size*seq_length) + + # we feed in the input embeddings (not word indices as in NLP) to the model + transformer_outputs = self.transformer( + inputs_embeds=stacked_inputs, + attention_mask=stacked_attention_mask, + use_cache=False + ) + x = transformer_outputs['last_hidden_state'] + + # reshape x so that the second dimension corresponds to the original + # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t + # or rewards (3) + x = x.reshape(batch_size, seq_length, transition_size, self.hidden_size).permute(0, 2, 1, 3) + + state_reps = x[:,1] + action_reps = x[:,2] + + # get predictions + return_preds = self.predict_return(action_reps) # predict next return given state and action + state_preds = self.predict_state(action_reps) # predict next state given state and action + + + action_log_probs = None + entropies = None + if self.stochastic: + + means = self.predict_action_mean(state_reps) + log_stds = self.predict_action_logstd(state_reps) + + # Bound log of standard deviations + log_stds = torch.clamp(log_stds, self.log_std_min, self.log_std_max) + stds = torch.exp(log_stds) + + #action_distributions = TransformedDistribution(Normal(means, stds), TanhTransform(cache_size=1)) + #action_distributions = Normal(means, stds) + + if self.stochastic_tanh: + action_distributions = Independent(TransformedDistribution(Normal(means, stds), TanhTransform(cache_size=1)),1) + else: + action_distributions = Independent(Normal(means, stds),1) + # Sample from distribution or predict mean + if use_means: + if self.stochastic_tanh: + action_preds = torch.tanh(action_distributions.base_dist.base_dist.mean) + else: + action_preds = action_distributions.mean + else: + action_preds = action_distributions.rsample() + + if target_actions != None: + # Clamp target actions to prevent nans + eps = torch.finfo(target_actions.dtype).eps + target_actions = torch.clamp(target_actions, -1+eps, 1-eps) + action_log_probs = action_distributions.log_prob(target_actions) + #entropies = action_distributions.base_dist.entropy() + if self.stochastic_tanh: + entropies = -action_distributions.log_prob(action_distributions.rsample(sample_shape=torch.Size([self.approximate_entropy_samples]))).mean(dim=0) + else: + entropies = action_distributions.entropy() + + + else: + action_preds = self.predict_action(x[:,1]) # predict next action given state + + return state_preds, action_preds, return_preds, action_log_probs, entropies + + def get_action(self, states, actions, rewards, returns_to_go, timesteps, use_means=False, custom_max_length=None,**kwargs): + # we don't care about the past rewards in this model + + states = states.reshape(1, -1, self.state_dim) + actions = actions.reshape(1, -1, self.act_dim) + returns_to_go = returns_to_go.reshape(1, -1, 1) + timesteps = timesteps.reshape(1, -1) + + max_length = self.max_length + if custom_max_length is not None: + max_length = custom_max_length + if max_length is not 
None: + states = states[:,-max_length:] + actions = actions[:,-max_length:] + returns_to_go = returns_to_go[:,-max_length:] + timesteps = timesteps[:,-max_length:] + + # pad all tokens to sequence length + attention_mask = torch.cat([torch.zeros(max_length-states.shape[1]), torch.ones(states.shape[1])]) + attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1) + states = torch.cat( + [torch.zeros((states.shape[0], max_length-states.shape[1], self.state_dim), device=states.device), states], + dim=1).to(dtype=torch.float32) + actions = torch.cat( + [torch.zeros((actions.shape[0], max_length - actions.shape[1], self.act_dim), + device=actions.device), actions], + dim=1).to(dtype=torch.float32) + returns_to_go = torch.cat( + [torch.zeros((returns_to_go.shape[0], max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go], + dim=1).to(dtype=torch.float32) + timesteps = torch.cat( + [torch.zeros((timesteps.shape[0], max_length-timesteps.shape[1]), device=timesteps.device), timesteps], + dim=1 + ).to(dtype=torch.long) + else: + attention_mask = None + + state_preds, action_preds, return_preds, _, _ = self.forward( + states, actions, rewards, returns_to_go, timesteps, attention_mask=attention_mask, use_means=use_means, **kwargs) + return action_preds[0,-1] + diff --git a/gym/decision_transformer/models/mlp_bc.py b/gym/decision_transformer/models/mlp_bc.py new file mode 100644 index 0000000..7459733 --- /dev/null +++ b/gym/decision_transformer/models/mlp_bc.py @@ -0,0 +1,51 @@ +import numpy as np +import torch +import torch.nn as nn + +from decision_transformer.models.model import TrajectoryModel + + +class MLPBCModel(TrajectoryModel): + + """ + Simple MLP that predicts next action a from past states s. 
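+    The most recent `max_length` states are concatenated into a single vector and mapped
+    to an action through an MLP with a Tanh output layer (behavior cloning baseline).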
+ """ + + def __init__(self, state_dim, act_dim, hidden_size, n_layer, dropout=0.1, max_length=1, **kwargs): + super().__init__(state_dim, act_dim) + + self.hidden_size = hidden_size + self.max_length = max_length + + layers = [nn.Linear(max_length*self.state_dim, hidden_size)] + for _ in range(n_layer-1): + layers.extend([ + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_size, hidden_size) + ]) + layers.extend([ + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_size, self.act_dim), + nn.Tanh(), + ]) + + self.model = nn.Sequential(*layers) + + def forward(self, states, actions, rewards, attention_mask=None, target_return=None): + + states = states[:,-self.max_length:].reshape(states.shape[0], -1) # concat states + actions = self.model(states).reshape(states.shape[0], 1, self.act_dim) + + return None, actions, None + + def get_action(self, states, actions, rewards, **kwargs): + states = states.reshape(1, -1, self.state_dim) + if states.shape[1] < self.max_length: + states = torch.cat( + [torch.zeros((1, self.max_length-states.shape[1], self.state_dim), + dtype=torch.float32, device=states.device), states], dim=1) + states = states.to(dtype=torch.float32) + _, actions, _ = self.forward(states, None, None, **kwargs) + return actions[0,-1] diff --git a/gym/decision_transformer/models/model.py b/gym/decision_transformer/models/model.py new file mode 100644 index 0000000..92593d7 --- /dev/null +++ b/gym/decision_transformer/models/model.py @@ -0,0 +1,21 @@ +import numpy as np +import torch +import torch.nn as nn + + +class TrajectoryModel(nn.Module): + + def __init__(self, state_dim, act_dim, max_length=None): + super().__init__() + + self.state_dim = state_dim + self.act_dim = act_dim + self.max_length = max_length + + def forward(self, states, actions, rewards, masks=None, attention_mask=None): + # "masked" tokens or unspecified inputs can be passed in as None + return None, None, None + + def get_action(self, states, actions, rewards, **kwargs): + # these will come as tensors on the correct device + return torch.zeros_like(actions[-1]) diff --git a/gym/decision_transformer/models/trajectory_gpt2.py b/gym/decision_transformer/models/trajectory_gpt2.py new file mode 100644 index 0000000..e05b8d3 --- /dev/null +++ b/gym/decision_transformer/models/trajectory_gpt2.py @@ -0,0 +1,775 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch OpenAI GPT-2 model.""" + +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import ( + Conv1D, + PreTrainedModel, + SequenceSummary, + find_pruneable_heads_and_indices, + prune_conv1d_layer, +) +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GPT2Config" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class Attention(nn.Module): + def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): + super().__init__() + + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implem] + assert n_state % config.n_head == 0 + self.register_buffer( + "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx) + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + self.n_head = config.n_head + self.split_size = 
n_state + self.scale = scale + self.is_cross_attention = is_cross_attention + if self.is_cross_attention: + self.c_attn = Conv1D(2 * n_state, nx) + self.q_attn = Conv1D(n_state, nx) + else: + self.c_attn = Conv1D(3 * n_state, nx) + self.c_proj = Conv1D(n_state, nx) + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_head, self.split_size // self.n_head, self.pruned_heads + ) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) + self.n_head = self.n_head - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): + w = torch.matmul(q, k) + if self.scale: + w = w / (float(v.size(-1)) ** 0.5) + nd, ns = w.size(-2), w.size(-1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + mask = self.bias[:, :, ns - nd: ns, :ns] + w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) + + if attention_mask is not None: + # Apply the attention mask + w = w + attention_mask + + w = nn.Softmax(dim=-1)(w) + w = self.attn_dropout(w) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [torch.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = x.permute(0, 2, 1, 3).contiguous() + new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) + return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, x, k=False): + new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) + x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states + if k: + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) + else: + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=False, + output_attentions=False, + ): + if encoder_hidden_states is not None: + assert hasattr( + self, "q_attn" + ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`." 
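+            # Cross-attention: queries are projected from this stream, while keys and
+            # values are projected from the encoder hidden states.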
+ query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self.split_heads(query) + key = self.split_heads(key, k=True) + value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + key = torch.cat((past_key, key), dim=-1) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking + else: + present = (None,) + + attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a) + + outputs = [a, present] + attn_outputs[1:] + return outputs # a, present, (attentions) + + +class MLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super().__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class AdapterMLP(nn.Module): + def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) + super().__init__() + nx = config.n_embd + self.c_fc = Conv1D(n_state, nx) + self.c_proj = Conv1D(nx, n_state) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, x): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + return self.dropout(h2) + + +class Block(nn.Module): + def __init__(self, n_ctx, config, scale=False): + super().__init__() + hidden_size = config.n_embd + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = Attention(hidden_size, n_ctx, config, scale) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if config.add_cross_attention: + self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = MLP(inner_dim, config) + # self.adapter_mlp = AdapterMLP(512, config) # ADAPTER + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=False, + output_attentions=False, + ): + attn_outputs = self.attn( + self.ln_1(hidden_states), + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + hidden_states + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + assert hasattr( + self, "crossattention" + ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + cross_attn_outputs = self.crossattention( 
+ self.ln_cross_attn(hidden_states), + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = hidden_states + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) + # residual connection + hidden_states = hidden_states + feed_forward_hidden_states + # hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states)) + + outputs = [hidden_states] + outputs + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + # module.weight.data.fill_(.01) # KL: Adapter change + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): + Language modeling loss. + mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): + Multiple choice classification loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, + batch_size, num_heads, sequence_length, embed_size_per_head)`). + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): + :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else + ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input + sequence tokens in the vocabulary. + If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be + passed as ``input_ids``. + Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which + have their past given to this model should not be passed as ``input_ids`` as they have already been + computed. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? 
<../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see + :obj:`past_key_values`). + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + Example:: + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained('gpt2-xl') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. 
+ Example:: + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained('gpt2-large') + device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wte = nn.Embedding(config.vocab_size, config.n_embd) + # self.wpe = nn.Embedding(config.n_positions, config.n_embd) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + self.use_layers = None + + def set_layers(self, num_layers): + assert 1 <= num_layers <= len(self.h) + if num_layers is not None: + num_layers -= 1 + self.use_layers = num_layers + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="gpt2", + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = [None] * len(self.h) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
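The comments above describe how the 2D padding mask is converted into an additive bias; as a quick standalone illustration (plain PyTorch, independent of the surrounding model code), a 0/1 mask becomes a large negative offset so that masked positions end up with essentially zero attention weight after the softmax:

```
import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.3]])    # raw attention scores for one query
pad_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # 1 = attend, 0 = padding
bias = (1.0 - pad_mask) * -10000.0               # 0.0 for real tokens, -10000.0 for padding
probs = torch.softmax(scores + bias, dim=-1)
print(probs)  # the last position gets ~0 probability, as if it were removed entirely
```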
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + # position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds # + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if self.use_layers is not None and i >= self.use_layers: + break + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = layer_past.to(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + # checkpointing only works with tuple returns, not with lists + return tuple(output for output in module(*inputs, use_cache, output_attentions)) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states, present = outputs[:2] + if use_cache is True: + presents = presents + (present,) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on 
the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) diff --git a/gym/decision_transformer/training/act_trainer.py b/gym/decision_transformer/training/act_trainer.py new file mode 100644 index 0000000..d18ba0f --- /dev/null +++ b/gym/decision_transformer/training/act_trainer.py @@ -0,0 +1,29 @@ +import numpy as np +import torch + +from decision_transformer.training.trainer import Trainer + + +class ActTrainer(Trainer): + + def train_step(self): + states, actions, rewards, dones, rtg, _, attention_mask = self.get_batch(self.batch_size) + state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) + + state_preds, action_preds, reward_preds = self.model.forward( + states, actions, rewards, attention_mask=attention_mask, target_return=rtg[:,0], + ) + + act_dim = action_preds.shape[2] + action_preds = action_preds.reshape(-1, act_dim) + action_target = action_target[:,-1].reshape(-1, act_dim) + + loss = self.loss_fn( + state_preds, action_preds, reward_preds, + state_target, action_target, reward_target, + ) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return loss.detach().cpu().item() diff --git a/gym/decision_transformer/training/seq_trainer.py b/gym/decision_transformer/training/seq_trainer.py new file mode 100644 index 0000000..8c4ffe3 --- /dev/null +++ b/gym/decision_transformer/training/seq_trainer.py @@ -0,0 +1,57 @@ +import numpy as np +import torch + +from decision_transformer.training.trainer import Trainer + + +class SequenceTrainer(Trainer): + + def train_step(self): + states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size) + action_target = torch.clone(actions) + state_target = torch.clone(states) + rtg_target = torch.clone(rtg[:,:-1]) + + state_preds, action_preds, return_preds, action_log_probs, entropies = self.model.forward( + states, actions, rewards, rtg[:,:-1], timesteps, attention_mask=attention_mask,target_actions=action_target + ) + + act_dim = action_preds.shape[2] + state_dim = state_preds.shape[2] + action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] + action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0] + + if action_log_probs != None: + action_log_probs = action_log_probs.reshape(-1)[attention_mask.reshape(-1) > 0] + if entropies != None: + entropies = entropies.reshape(-1)[attention_mask.reshape(-1) > 0] + + loss = self.loss_fn( + state_preds, action_preds, return_preds, None, + state_target, action_target, rtg_target, None, + action_log_probs, entropies + ) + + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25) + self.optimizer.step() + + + # Entropy multiplier tuning + if self.log_entropy_multiplier is not None: + 
entropy_multiplier_loss = self.entropy_loss_fn(entropies) + self.multiplier_optimizer.zero_grad() + entropy_multiplier_loss.backward() + self.multiplier_optimizer.step() + + entropy_loss = entropy_multiplier_loss.detach().cpu().item() + else: + entropy_loss = None + + with torch.no_grad(): + self.diagnostics['training/action_error'] = torch.mean((action_preds-action_target)**2).detach().cpu().item() + if self.log_entropy_multiplier is not None: + self.diagnostics['training/entropy_multiplier'] = torch.exp(self.log_entropy_multiplier).detach().cpu().item() + self.diagnostics['training/entropy'] = torch.mean(entropies).item() + return loss.detach().cpu().item(), entropy_loss diff --git a/gym/decision_transformer/training/trainer.py b/gym/decision_transformer/training/trainer.py new file mode 100644 index 0000000..1c2d316 --- /dev/null +++ b/gym/decision_transformer/training/trainer.py @@ -0,0 +1,94 @@ +import numpy as np +import torch + +import time + + +class Trainer: + + def __init__(self, model, optimizer, batch_size, get_batch, loss_fn, scheduler=None, eval_fns=None, log_entropy_multiplier = None, multiplier_optimizer = None, multiplier_scheduler = None, entropy_loss_fn=None): + self.model = model + self.optimizer = optimizer + self.batch_size = batch_size + self.get_batch = get_batch + self.loss_fn = loss_fn + self.scheduler = scheduler + # Optional entropy multiplier and its loss, optimizer, scheduler + self.log_entropy_multiplier = log_entropy_multiplier + self.entropy_loss_fn = entropy_loss_fn + self.multiplier_optimizer = multiplier_optimizer + self.multiplier_scheduler = multiplier_scheduler + + self.eval_fns = [] if eval_fns is None else eval_fns + self.diagnostics = dict() + + + self.start_time = time.time() + + def train_iteration(self, num_steps, iter_num=0, print_logs=False): + + train_losses = [] + entropy_losses = [] + logs = dict() + + train_start = time.time() + + self.model.train() + for _ in range(num_steps): + train_loss, entropy_loss = self.train_step() + if entropy_loss is not None: + entropy_losses.append(entropy_loss) + train_losses.append(train_loss) + if self.scheduler is not None: + self.scheduler.step() + if self.multiplier_scheduler is not None: + self.multiplier_scheduler.step() + + logs['time/training'] = time.time() - train_start + + eval_start = time.time() + + self.model.eval() + for eval_fn in self.eval_fns: + outputs = eval_fn(self.model) + for k, v in outputs.items(): + logs[f'evaluation/{k}'] = v + + logs['time/total'] = time.time() - self.start_time + logs['time/evaluation'] = time.time() - eval_start + logs['training/train_loss_mean'] = np.mean(train_losses) + logs['training/train_loss_std'] = np.std(train_losses) + if self.log_entropy_multiplier is not None: + logs['training/entropy_multiplier_loss_mean'] = np.mean(entropy_losses) + logs['training/entropy_multiplier__loss_std'] = np.std(entropy_losses) + + + for k in self.diagnostics: + logs[k] = self.diagnostics[k] + + if print_logs: + print('=' * 80) + print(f'Iteration {iter_num}') + for k, v in logs.items(): + print(f'{k}: {v}') + + return logs + + def train_step(self): + states, actions, rewards, dones, attention_mask, returns = self.get_batch(self.batch_size) + state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards) + + state_preds, action_preds, reward_preds = self.model.forward( + states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns, + ) + + # note: currently indexing & masking is not fully correct + 
loss = self.loss_fn( + state_preds, action_preds, reward_preds, + state_target[:,1:], action_target, reward_target[:,1:], + ) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return loss.detach().cpu().item() diff --git a/gym/experiment.py b/gym/experiment.py new file mode 100644 index 0000000..0ea4c19 --- /dev/null +++ b/gym/experiment.py @@ -0,0 +1,465 @@ +import gym +import numpy as np +import torch +import wandb + +import argparse +import pickle +import random +import sys +import os +import pathlib + +from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg +from decision_transformer.models.decision_transformer import DecisionTransformer +from decision_transformer.models.mlp_bc import MLPBCModel +from decision_transformer.training.act_trainer import ActTrainer +from decision_transformer.training.seq_trainer import SequenceTrainer + + +def discount_cumsum(x, gamma): + discount_cumsum = np.zeros_like(x) + discount_cumsum[-1] = x[-1] + for t in reversed(range(x.shape[0]-1)): + discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] + return discount_cumsum + + +def experiment( + exp_prefix, + variant, +): + device = variant.get('device', 'cuda') + log_to_wandb = variant.get('log_to_wandb', False) + + + env_name, dataset = variant['env'], variant['dataset'] + model_type = variant['model_type'] + group_name = f'{exp_prefix}-{env_name}-{dataset}' + exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}' + + model_dir = os.path.join(pathlib.Path(__file__).parent.resolve(),f'./models/{env_name}/') + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + if env_name == 'hopper': + env = gym.make('Hopper-v3') + max_ep_len = 1000 + env_targets = [3600, 1800] # evaluation conditioning targets + scale = 1000. # normalization for rewards/returns + elif env_name == 'halfcheetah': + env = gym.make('HalfCheetah-v3') + max_ep_len = 1000 + env_targets = [12000, 6000] + scale = 1000. + elif env_name == 'walker2d': + env = gym.make('Walker2d-v3') + max_ep_len = 1000 + env_targets = [5000, 2500] + scale = 1000. + elif env_name == 'reacher2d': + from decision_transformer.envs.reacher_2d import Reacher2dEnv + env = Reacher2dEnv() + max_ep_len = 100 + env_targets = [76, 40] + scale = 10. + else: + raise NotImplementedError + + # Override env_targets / set different training target for online decision transformer, following paper + if variant['online_training']: + if env_name == 'hopper': + env_targets = [3600] # evaluation conditioning targets + target_online = 7200 + elif env_name == 'halfcheetah': + env_targets = [6000] + target_online = 12000 + elif env_name == 'walker2d': + env_targets = [5000] + target_online = 10000 + else: + raise NotImplementedError + + if model_type == 'bc': + env_targets = env_targets[:1] # since BC ignores target, no need for different evaluations + + state_dim = env.observation_space.shape[0] + act_dim = env.action_space.shape[0] + + # load dataset + dataset_path = f'data/{env_name}-{dataset}-v2.pkl' + with open(dataset_path, 'rb') as f: + trajectories = pickle.load(f) + + # save all path information into separate lists + mode = variant.get('mode', 'normal') + states, traj_lens, returns = [], [], [] + for path in trajectories: + if mode == 'delayed': # delayed: all rewards moved to end of trajectory + path['rewards'][-1] = path['rewards'].sum() + path['rewards'][:-1] = 0. 
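As a quick sanity check of the two reward transformations used here (purely illustrative, not called by the training code): `discount_cumsum` turns a reward sequence into returns-to-go, and `delayed` mode moves the whole episode return onto the final step. The helper is restated so the snippet runs standalone:

```
import numpy as np

def discount_cumsum(x, gamma):
    # same recurrence as in experiment.py: rtg[t] = x[t] + gamma * rtg[t+1]
    out = np.zeros_like(x)
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

rewards = np.array([1.0, 2.0, 3.0])
print(discount_cumsum(rewards, 1.0))  # [6. 5. 3.] -> returns-to-go used for conditioning

delayed = rewards.copy()              # 'delayed' mode: sparse-reward variant
delayed[-1] = delayed.sum()
delayed[:-1] = 0.0
print(delayed)                        # [0. 0. 6.]
```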
+ states.append(path['observations']) + traj_lens.append(len(path['observations'])) + returns.append(path['rewards'].sum()) + traj_lens, returns = np.array(traj_lens), np.array(returns) + + # used for input normalization + states = np.concatenate(states, axis=0) + state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + + num_timesteps = sum(traj_lens) + + print('=' * 50) + print(f'Starting new experiment: {env_name} {dataset}') + print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found') + print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}') + print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}') + print('=' * 50) + + K = variant['K'] + batch_size = variant['batch_size'] + num_eval_episodes = variant['num_eval_episodes'] + pct_traj = variant.get('pct_traj', 1.) + + # only train on top pct_traj trajectories (for %BC experiment) + num_timesteps = max(int(pct_traj*num_timesteps), 1) + sorted_inds = np.argsort(returns) # lowest to highest + num_trajectories = 1 + timesteps = traj_lens[sorted_inds[-1]] + ind = len(trajectories) - 2 + while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps: + timesteps += traj_lens[sorted_inds[ind]] + num_trajectories += 1 + ind -= 1 + sorted_inds = sorted_inds[-num_trajectories:] + + # used to reweight sampling so we sample according to timesteps instead of trajectories + p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds]) + + # Sort trajectories from worst to best and cut to buffer size + if variant['online_training']: + trajectories = [trajectories[index] for index in sorted_inds] + trajectories = trajectories[:variant['online_buffer_size']] + num_trajectories = len(trajectories) + + starting_p_sample = p_sample + def get_batch(batch_size=256, max_len=K): + # Dynamically recompute p_sample if online training + if variant['online_training']: + traj_lens = np.array([len(path['observations']) for path in trajectories]) + p_sample = traj_lens / sum(traj_lens) + else: + p_sample = starting_p_sample + + + batch_inds = np.random.choice( + np.arange(num_trajectories), + size=batch_size, + replace=True, + p=p_sample, # reweights so we sample according to timesteps + ) + + s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], [] + for i in range(batch_size): + if variant['online_training']: + traj = trajectories[batch_inds[i]] + else: + traj = trajectories[int(sorted_inds[batch_inds[i]])] + si = random.randint(0, traj['rewards'].shape[0] - 1) + + # get sequences from dataset + s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim)) + a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim)) + r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1)) + if 'terminals' in traj: + d.append(traj['terminals'][si:si + max_len].reshape(1, -1)) + else: + d.append(traj['dones'][si:si + max_len].reshape(1, -1)) + timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1)) + timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len-1 # padding cutoff + rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1)) + if rtg[-1].shape[1] <= s[-1].shape[1]: + rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1) + + # padding and state + reward normalization + tlen = s[-1].shape[1] + s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1) + s[-1] = (s[-1] - state_mean) / state_std + a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * 0., a[-1]], 
axis=1) + r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1) + d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1) + rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale + timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1) + mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1)) + + s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device) + a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device) + r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device) + d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device) + rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device) + timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device) + mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device) + + return s, a, r, d, rtg, timesteps, mask + + if variant['online_training']: + # If online training, use means during eval, but (not during exploration) + variant['use_action_means'] = True + + def eval_episodes(target_rew): + def fn(model): + returns, lengths = [], [] + for _ in range(num_eval_episodes): + with torch.no_grad(): + if model_type == 'dt': + ret, length = evaluate_episode_rtg( + env, + state_dim, + act_dim, + model, + max_ep_len=max_ep_len, + scale=scale, + target_return=target_rew/scale, + mode=mode, + state_mean=state_mean, + state_std=state_std, + device=device, + use_means=variant['use_action_means'], + eval_context=variant['eval_context'] + ) + else: + ret, length = evaluate_episode( + env, + state_dim, + act_dim, + model, + max_ep_len=max_ep_len, + target_return=target_rew/scale, + mode=mode, + state_mean=state_mean, + state_std=state_std, + device=device, + ) + returns.append(ret) + lengths.append(length) + return { + f'target_{target_rew}_return_mean': np.mean(returns), + f'target_{target_rew}_return_std': np.std(returns), + f'target_{target_rew}_length_mean': np.mean(lengths), + f'target_{target_rew}_length_std': np.std(lengths), + } + return fn + + + if model_type == 'dt': + if variant['pretrained_model']: + model = torch.load(variant['pretrained_model'],map_location='cuda:0') + model.stochastic_tanh = variant['stochastic_tanh'] + model.approximate_entropy_samples = variant['approximate_entropy_samples'] + model.to(device) + + else: + model = DecisionTransformer( + state_dim=state_dim, + act_dim=act_dim, + max_length=K, + max_ep_len=max_ep_len*2, + hidden_size=variant['embed_dim'], + n_layer=variant['n_layer'], + n_head=variant['n_head'], + n_inner=4*variant['embed_dim'], + activation_function=variant['activation_function'], + n_positions=1024, + resid_pdrop=variant['dropout'], + attn_pdrop=variant['dropout'], + stochastic = variant['stochastic'], + remove_pos_embs=variant['remove_pos_embs'], + approximate_entropy_samples = variant['approximate_entropy_samples'], + stochastic_tanh=variant['stochastic_tanh'] + ) + elif model_type == 'bc': + model = MLPBCModel( + state_dim=state_dim, + act_dim=act_dim, + max_length=K, + hidden_size=variant['embed_dim'], + n_layer=variant['n_layer'], + ) + else: + raise NotImplementedError + + model = model.to(device=device) + warmup_steps = variant['warmup_steps'] + optimizer = torch.optim.AdamW( + model.parameters(), + lr=variant['learning_rate'], + weight_decay=variant['weight_decay'], + ) + 
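For reference, `get_batch` above left-pads every sampled window to length K and builds the attention mask accordingly; a minimal standalone sketch of that convention (illustrative shapes only, not used by the experiment code):

```
import numpy as np

K, state_dim = 5, 3
window = np.random.randn(1, 3, state_dim)  # a sampled sub-trajectory of length 3 (< K)

tlen = window.shape[1]
padded = np.concatenate([np.zeros((1, K - tlen, state_dim)), window], axis=1)
mask = np.concatenate([np.zeros((1, K - tlen)), np.ones((1, tlen))], axis=1)

print(padded.shape)  # (1, 5, 3): zero states occupy the two left-most (oldest) slots
print(mask)          # [[0. 0. 1. 1. 1.]]: the transformer ignores the padded timesteps
```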
scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lambda steps: min((steps+1)/warmup_steps, 1) + ) + + if variant['online_training']: + assert(variant['pretrained_model'] is not None), "Must specify pretrained model to perform online finetuning" + variant['use_entropy'] = True + + if variant['online_training'] and variant['target_entropy']: + # Setup variable and optimizer for (log of) lagrangian multiplier used for entropy constraint + # We optimize the log of the multiplier b/c lambda >= 0 + log_entropy_multiplier = torch.zeros(1, requires_grad=True, device=device) + multiplier_optimizer = torch.optim.AdamW( + [log_entropy_multiplier], + lr=variant['learning_rate'], + weight_decay=variant['weight_decay'], + ) + # multiplier_optimizer = torch.optim.Adam( + # [log_entropy_multiplier], + # lr=1e-3 + # #lr=variant['learning_rate'], + # ) + multiplier_scheduler = torch.optim.lr_scheduler.LambdaLR( + multiplier_optimizer, + lambda steps: min((steps+1)/warmup_steps, 1) + ) + else: + log_entropy_multiplier = None + multiplier_optimizer = None + multiplier_scheduler = None + + entropy_loss_fn = None + if variant['stochastic']: + if variant['use_entropy']: + if variant['target_entropy']: + loss_fn = lambda s_hat, a_hat, rtg_hat,r_hat, s, a, rtg, r, a_log_prob, entropies: -torch.mean(a_log_prob) - torch.exp(log_entropy_multiplier.detach()) * torch.mean(entropies) + target_entropy = -act_dim + entropy_loss_fn = lambda entropies: torch.exp(log_entropy_multiplier) * (torch.mean(entropies.detach()) - target_entropy) + else: + loss_fn = lambda s_hat, a_hat, rtg_hat,r_hat, s, a, rtg, r, a_log_prob, entropies: -torch.mean(a_log_prob) - torch.mean(entropies) + else: + loss_fn = lambda s_hat, a_hat, rtg_hat, r_hat, s, a, rtg,r, a_log_prob, entropies: -torch.mean(a_log_prob) + else: + loss_fn = lambda s_hat, a_hat, rtg_hat, r_hat, s, a, rtg, r, a_log_prob, entropies: torch.mean((a_hat - a)**2) + + if model_type == 'dt': + trainer = SequenceTrainer( + model=model, + optimizer=optimizer, + batch_size=batch_size, + get_batch=get_batch, + scheduler=scheduler, + loss_fn=loss_fn, + log_entropy_multiplier=log_entropy_multiplier, + entropy_loss_fn=entropy_loss_fn, + multiplier_optimizer=multiplier_optimizer, + multiplier_scheduler=multiplier_scheduler, + eval_fns=[eval_episodes(tar) for tar in env_targets], + ) + elif model_type == 'bc': + trainer = ActTrainer( + model=model, + optimizer=optimizer, + batch_size=batch_size, + get_batch=get_batch, + scheduler=scheduler, + loss_fn=loss_fn, + eval_fns=[eval_episodes(tar) for tar in env_targets], + ) + + if log_to_wandb: + wandb.init( + name=exp_prefix, + group=group_name, + project='decision-transformer', + config=variant + ) + # wandb.watch(model) # wandb has some bug + if variant['eval_only']: + model.eval() + eval_fns = [eval_episodes(tar) for tar in env_targets] + + for iter_num in range(variant['max_iters']): + logs = {} + for eval_fn in eval_fns: + outputs = eval_fn(model) + for k, v in outputs.items(): + logs[f'evaluation/{k}'] = v + + print('=' * 80) + print(f'Iteration {iter_num}') + for k, v in logs.items(): + print(f'{k}: {v}') + else: + if variant['online_training']: + for iter in range(variant['max_iters']): + # Collect new rollout, using stochastic policy + ret, length, traj = evaluate_episode_rtg( + env, + state_dim, + act_dim, + model, + max_ep_len=max_ep_len, + scale=scale, + target_return=target_online/scale, + mode=mode, + state_mean=state_mean, + state_std=state_std, + device=device, + use_means=False, + return_traj=True + ) + # Remove 
oldest trajectory, add new trajectory
+                trajectories = trajectories[1:]
+                trajectories.append(traj)
+
+                # Perform update, eval using deterministic policy
+                outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True)
+                if log_to_wandb:
+                    wandb.log(outputs)
+        else:
+            for iter in range(variant['max_iters']):
+                outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter+1, print_logs=True)
+                if log_to_wandb:
+                    wandb.log(outputs)
+
+    torch.save(model, os.path.join(model_dir, model_type + '_' + exp_prefix + '.pt'))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--env', type=str, default='hopper')
+    parser.add_argument('--dataset', type=str, default='medium') # medium, medium-replay, medium-expert, expert
+    parser.add_argument('--mode', type=str, default='normal') # normal for standard setting, delayed for sparse
+    parser.add_argument('--K', type=int, default=20)
+    parser.add_argument('--pct_traj', type=float, default=1.)
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--model_type', type=str, default='dt') # dt for decision transformer, bc for behavior cloning
+    parser.add_argument('--embed_dim', type=int, default=128)
+    parser.add_argument('--n_layer', type=int, default=3)
+    parser.add_argument('--n_head', type=int, default=1)
+    parser.add_argument('--activation_function', type=str, default='relu')
+    parser.add_argument('--dropout', type=float, default=0.1)
+    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
+    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
+    parser.add_argument('--warmup_steps', type=int, default=10000)
+    parser.add_argument('--num_eval_episodes', type=int, default=100)
+    parser.add_argument('--max_iters', type=int, default=10)
+    parser.add_argument('--num_steps_per_iter', type=int, default=10000)
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--log_to_wandb', '-w', type=bool, default=False)
+    parser.add_argument('--save_model', default=False, action='store_true')
+    parser.add_argument('--pretrained_model', default=None, type=str)
+    parser.add_argument('--stochastic', default=False, action='store_true')
+    parser.add_argument('--use_entropy', default=False, action='store_true')
+    parser.add_argument('--use_action_means', default=False, action='store_true')
+    parser.add_argument('--online_training', default=False, action='store_true')
+    parser.add_argument('--online_buffer_size', default=1000, type=int) # keep top N trajectories for online training in replay buffer to start
+    parser.add_argument('--eval_only', default=False, action='store_true')
+    parser.add_argument('--remove_pos_embs', default=False, action='store_true')
+    parser.add_argument('--eval_context', default=None, type=int)
+    parser.add_argument('--target_entropy', default=False, action='store_true')
+    parser.add_argument('--stochastic_tanh', default=False, action='store_true')
+    parser.add_argument('--approximate_entropy_samples', default=1000, type=int, help="if using a stochastic network with tanh squashing, entropy has to be approximated with k samples, as there is no analytical solution")
+    args = parser.parse_args()
+
+    experiment('gym-experiment', variant=vars(args))
diff --git a/gym/models/halfcheetah/dt_gym-experiment-halfcheetah-medium-241050.pt b/gym/models/halfcheetah/dt_gym-experiment-halfcheetah-medium-241050.pt
new file mode 100644
index 0000000..772b22e
Binary files /dev/null and
b/gym/models/halfcheetah/dt_gym-experiment-halfcheetah-medium-241050.pt differ diff --git a/gym/readme-gym.md b/gym/readme-gym.md new file mode 100644 index 0000000..33ae9b1 --- /dev/null +++ b/gym/readme-gym.md @@ -0,0 +1,33 @@ + +# OpenAI Gym +Modifying + +## Installation + +Experiments require MuJoCo. +Follow the instructions in the [mujoco-py repo](https://github.com/openai/mujoco-py) to install. +Then, dependencies can be installed with the following command: + +``` +conda env create -f conda_env.yml +``` + +## Downloading datasets + +Datasets are stored in the `data` directory. +Install the [D4RL repo](https://github.com/rail-berkeley/d4rl), following the instructions there. +Then, run the following script in order to download the datasets and save them in our format: + +``` +python download_d4rl_datasets.py +``` + +## Example usage + +Experiments can be reproduced with the following: + +``` +python experiment.py --env hopper --dataset medium --model_type dt +``` + +Adding `-w True` will log results to Weights and Biases. diff --git a/run_docker.sh b/run_docker.sh new file mode 100644 index 0000000..3d7dcec --- /dev/null +++ b/run_docker.sh @@ -0,0 +1 @@ +docker run -it --mount "type=bind,source=$(pwd),target=/app/dt" --entrypoint /bin/bash --gpus=all dt_experiments diff --git a/runs.md b/runs.md new file mode 100644 index 0000000..a7fa8ea --- /dev/null +++ b/runs.md @@ -0,0 +1,27 @@ +TODO, update +Commands: + +Hopper: +Pretraining: + +python experiment.py --env hopper --dataset medium --model_type dt --num_eval_episodes=50 --max_iters=5 --num_steps_per_iter=1000 --stochastic --use_action_means --learning_rate=1e-4 --embed_dim=512 --weight_decay=5e-4 --K=20 --remove_pos_embs --n_layer=4 --n_head=4 --batch_size=256 --eval_context=5 --device=cuda:2 --log_to_wandb=True --stochastic_tanh + +Online finetuning: +python experiment.py --env hopper --dataset medium --model_type dt --pretrained_model=./models/hopper/dt_gym-experiment-hopper-medium-506105.pt --stochastic --use_action_means --online_training --eval_context=5 --K=20 --batch_size=256 --num_steps_per_iter=300 --max_iters=200 --num_eval_episodes=50 --stochastic_tanh --device=cuda:2 --log_to_wandb=True + +python experiment.py --env hopper --dataset medium --model_type dt --pretrained_model=./models/hopper/dt_gym-experiment-hopper-medium-506105.pt --stochastic --use_action_means --online_training --eval_context=5 --K=20 --batch_size=256 --num_steps_per_iter=300 --max_iters=200 --num_eval_episodes=50 --device=cuda:2 --target_entropy --log_to_wandb=True --stochastic_tanh + + +Walker2D: +#Fix, this is wrong +pretraining: +python experiment.py --env walker2d --dataset medium --model_type dt --num_eval_episodes=50 --max_iters=5 --num_steps_per_iter=2000 --stochastic --use_action_means --learning_rate=1e-3 --embed_dim=512 --weight_decay=1e-3 --K=20 --remove_pos_embs --n_layer=4 --n_head=4 --batch_size=256 --eval_context=5 --stochastic_tanh --device=cuda:2 --log_to_wandb=True + + + +python experiment.py --env walker2d --dataset medium --model_type dt --pretrained_model=./models/walker2d/dt_gym-experiment-walker2d-medium-763104.pt --stochastic --use_action_means --online_training --eval_context=5 --K=20 --batch_size=256 --num_steps_per_iter=300 --max_iters=200 --num_eval_episodes=50 --learning_rate=1e-3 --weight_decay=1e-3 --device=cuda:2 --log_to_wandb=True --target_entropy --stochastic_tanh + + + +Model-based testing: + python experiment.py --env halfcheetah --dataset medium --model_type dt --num_eval_episodes=10 --max_iters=1 
--num_steps_per_iter=0 --stochastic --device=cuda:1 --use_model --pretrained_model=./models/halfcheetah/dt_gym-experiment-halfcheetah-medium-268755.pt --pretrained_mode=static --use_action_means --plan_horizon=25 --number_rollouts=10 \ No newline at end of file
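
Note: the `--target_entropy` flag used in the online finetuning commands above constrains the policy entropy with a learned Lagrange multiplier (set up in `gym/experiment.py`). A compact, simplified sketch of that dual update, with made-up batch values for illustration:

```
import torch

act_dim = 3
target_entropy = -act_dim                      # same heuristic as experiment.py uses
log_multiplier = torch.zeros(1, requires_grad=True)
opt = torch.optim.AdamW([log_multiplier], lr=1e-4)

entropies = torch.full((256,), -3.5)           # pretend per-step policy entropies for one batch

# dual step: entropy (-3.5) is below the target (-3), so the multiplier grows,
# which strengthens the entropy bonus in the actor loss on the next update
loss = torch.exp(log_multiplier) * (entropies.detach().mean() - target_entropy)
opt.zero_grad()
loss.backward()
opt.step()
print(torch.exp(log_multiplier).item())        # slightly greater than 1.0
```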