Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minigrid doorkey environment #251

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions envpool/minigrid/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ package(default_visibility = ["//visibility:public"])
cc_library(
name = "minigrid_env",
srcs = [
"impl/minigrid_doorkey_env.cc",
"impl/minigrid_empty_env.cc",
"impl/minigrid_env.cc",
],
hdrs = [
"doorkey.h",
"empty.h",
"impl/minigrid_doorkey_env.h",
"impl/minigrid_empty_env.h",
"impl/minigrid_env.h",
"impl/utils.h",
Expand Down
92 changes: 92 additions & 0 deletions envpool/minigrid/doorkey.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2023 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ENVPOOL_MINIGRID_DOORKEY_H_
#define ENVPOOL_MINIGRID_DOORKEY_H_

#include <utility>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"
#include "envpool/minigrid/impl/minigrid_doorkey_env.h"
#include "envpool/minigrid/impl/minigrid_env.h"

namespace minigrid {

/// Static description of the DoorKey environment for EnvPool: the default
/// configuration plus the state (observation/info) and action specs.
class DoorKeyEnvFns {
 public:
  /// Default configuration values.
  ///  - size: 8
  ///  - agent_start_pos: (-1, -1)
  ///  - agent_start_dir: -1 (MiniGridEnv::PlaceAgent picks a random direction
  ///    when the start direction is -1)
  ///  - agent_view_size: 7
  // NOTE(review): (-1, -1) presumably means "no fixed start position" --
  // confirm against MiniGridEnv.
  static decltype(auto) DefaultConfig() {
    return MakeDict("size"_.Bind(8),
                    "agent_start_pos"_.Bind(std::pair<int, int>(-1, -1)),
                    "agent_start_dir"_.Bind(-1), "agent_view_size"_.Bind(7));
  }

  /// State spec built from the configuration:
  ///  - obs:direction: int spec with bounds {0, 3}
  ///  - obs:image: uint8 spec of shape (view, view, 3) with bounds {0, 255}
  ///  - info:agent_pos: two ints with bounds {0, size}
  template <typename Config>
  static decltype(auto) StateSpec(const Config& conf) {
    const int grid_size = conf["size"_];
    const int view = conf["agent_view_size"_];
    // The Spec temporaries are deliberately created inside Bind(...) so they
    // are forwarded exactly as in a direct call.
    return MakeDict("obs:direction"_.Bind(Spec<int>({-1}, {0, 3})),
                    "obs:image"_.Bind(Spec<uint8_t>({view, view, 3}, {0, 255})),
                    "info:agent_pos"_.Bind(Spec<int>({2}, {0, grid_size})));
  }

  /// Action spec: a single int action with bounds {0, 6}. The configuration
  /// is not consulted.
  template <typename Config>
  static decltype(auto) ActionSpec(const Config& conf) {
    return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 6})));
  }
};

using DoorKeyEnvSpec = EnvSpec<DoorKeyEnvFns>;
using FrameSpec = Spec<uint8_t>;

/// EnvPool-facing DoorKey environment. It combines the generic Env interface
/// (reset/step/state allocation) with the MiniGrid DoorKey grid logic.
class DoorKeyEnv : public Env<DoorKeyEnvSpec>, public MiniGridDoorKeyEnv {
 public:
  /// Builds the environment from the spec's configuration.
  // NOTE(review): "max_episode_steps" is not part of
  // DoorKeyEnvFns::DefaultConfig(); presumably it is injected by the common
  // EnvPool configuration -- confirm.
  DoorKeyEnv(const Spec& spec, int env_id)
      : Env<DoorKeyEnvSpec>(spec, env_id),
        MiniGridDoorKeyEnv(spec.config["size"_],
                           spec.config["agent_start_pos"_],
                           spec.config["agent_start_dir"_],
                           spec.config["max_episode_steps"_],
                           spec.config["agent_view_size"_]) {
    // Point the MiniGrid RNG pointer at gen_ so grid generation draws from
    // that generator.
    gen_ref_ = &gen_;
  }

  /// Reports whether the current episode has finished.
  bool IsDone() override { return done_; }

  /// Starts a new episode and emits the initial state with zero reward.
  void Reset() override {
    MiniGridReset();
    WriteState(0.0);
  }

  /// Applies one action and emits the resulting state with its reward.
  void Step(const Action& action) override {
    const int raw_action = action["action"_];
    const float step_reward = MiniGridStep(static_cast<Act>(raw_action));
    WriteState(step_reward);
  }

 private:
  /// Allocates an output state and fills in the observation image, the
  /// reward, the agent direction and the agent position.
  void WriteState(float reward) {
    State state = Allocate();
    GenImage(state["obs:image"_]);
    state["reward"_] = reward;
    state["obs:direction"_] = agent_dir_;
    state["info:agent_pos"_](0) = agent_pos_.first;
    state["info:agent_pos"_](1) = agent_pos_.second;
  }
};

using DoorKeyEnvPool = AsyncEnvPool<DoorKeyEnv>;

} // namespace minigrid

#endif // ENVPOOL_MINIGRID_DOORKEY_H_
75 changes: 75 additions & 0 deletions envpool/minigrid/impl/minigrid_doorkey_env.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2023 Garena Online Private Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "envpool/minigrid/impl/minigrid_doorkey_env.h"

#include <utility>
#include <vector>

namespace minigrid {

// Stores the DoorKey settings in the MiniGridEnv members. The room is square
// (width and height both equal `size`) and walls always block the agent's
// view.
MiniGridDoorKeyEnv::MiniGridDoorKeyEnv(int size,
                                       std::pair<int, int> agent_start_pos,
                                       int agent_start_dir, int max_steps,
                                       int agent_view_size) {
  // Grid geometry and episode limits.
  width_ = height_ = size;
  max_steps_ = max_steps;
  // Observation settings.
  agent_view_size_ = agent_view_size;
  see_through_walls_ = false;
  // Requested agent start state.
  agent_start_dir_ = agent_start_dir;
  agent_start_pos_ = agent_start_pos;
}

void MiniGridDoorKeyEnv::GenGrid() {
grid_.clear();
for (int i = 0; i < height_; ++i) {
std::vector<WorldObj> temp_vec(width_);
for (int j = 0; j < width_; ++j) {
temp_vec[j] = WorldObj(kEmpty);
}
grid_.emplace_back(temp_vec);
}
// generate the surrounding walls
for (int i = 0; i < width_; ++i) {
grid_[0][i] = WorldObj(kWall, kGrey);
grid_[height_ - 1][i] = WorldObj(kWall, kGrey);
}
for (int i = 0; i < height_; ++i) {
grid_[i][0] = WorldObj(kWall, kGrey);
grid_[i][width_ - 1] = WorldObj(kWall, kGrey);
}
// place a goal square in the bottom-right corner
grid_[height_ - 2][width_ - 2] = WorldObj(kGoal, kGreen);
// generate a vertical splitting wall
std::uniform_int_distribution<> x_dist(2, width_ - 3);
int x = x_dist(*gen_ref_);
for (int y = 0; y < height_; ++y) {
grid_[y][x] = WorldObj(kWall, kGrey);
}
// place the agent at a random position and orientation
// on the left side of the splitting wall
PlaceAgent(1, 1, x - 1, height_ - 2);
// place a door in the wall
std::uniform_int_distribution<> y_dist(1, height_ - 3);
int door_idx = y_dist(*gen_ref_);
grid_[door_idx][x] = WorldObj(kDoor, kYellow);
grid_[door_idx][x].SetDoorLocked(true);
grid_[door_idx][x].SetDoorOpen(false);
// place a yellow key on the left side
auto pos = PlaceObject(1, 1, x - 1, height_ - 2);
grid_[pos.second][pos.first] = WorldObj(kKey, kYellow);
}

} // namespace minigrid
35 changes: 35 additions & 0 deletions envpool/minigrid/impl/minigrid_doorkey_env.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2023 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_
#define ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_

#include <utility>

#include "envpool/minigrid/impl/minigrid_env.h"

namespace minigrid {

// MiniGrid "DoorKey" environment built on MiniGridEnv: it overrides grid
// generation to produce a walled room split by a vertical wall with a locked
// door, with the agent and a key on the left side and a goal square in the
// bottom-right corner.
class MiniGridDoorKeyEnv : public MiniGridEnv {
public:
// Creates a size x size environment.
//   size:            width and height of the grid
//   agent_start_pos: requested agent start position
//   agent_start_dir: requested agent start direction
//   max_steps:       stored as the environment's step limit
//   agent_view_size: stored as the agent's view size
MiniGridDoorKeyEnv(int size, std::pair<int, int> agent_start_pos,
int agent_start_dir, int max_steps, int agent_view_size);
// Regenerates the grid layout (walls, goal, door, agent and key).
void GenGrid() override;
};

} // namespace minigrid

#endif // ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_
70 changes: 62 additions & 8 deletions envpool/minigrid/impl/minigrid_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,23 +136,41 @@ void MiniGridEnv::PlaceAgent(int start_x, int start_y, int end_x, int end_y) {
end_x = (end_x == -1) ? width_ - 1 : end_x;
end_y = (end_y == -1) ? height_ - 1 : end_y;
CHECK(start_x <= end_x && start_y <= end_y);
agent_pos_.first = -1;
agent_pos_.second = -1;
auto pos = PlaceObject(start_x, start_y, end_x, end_y);
agent_pos_.first = pos.first;
agent_pos_.second = pos.second;
// Randomly select a direction
if (agent_start_dir_ == -1) {
std::uniform_int_distribution<> dir_dist(0, 3);
agent_dir_ = dir_dist(*gen_ref_);
}
}

// Pick a position for an object with x-index in [start_x, end_x] and y-index
// in [start_y, end_y], and return the chosen position as (x, y).
std::pair<int, int> MiniGridEnv::PlaceObject(int start_x, int start_y,
int end_x, int end_y) {
std::pair<int, int> result;
std::uniform_int_distribution<> x_dist(start_x, end_x);
std::uniform_int_distribution<> y_dist(start_y, end_y);
while (true) {
int x = x_dist(*gen_ref_);
int y = y_dist(*gen_ref_);
// don't place the object on top of another object
if (grid_[y][x].GetType() != kEmpty) {
continue;
}
agent_pos_.first = x;
agent_pos_.second = y;
// don't place the object where the agent is
if (agent_pos_.first == x && agent_pos_.second == y) {
continue;
}
result.first = x;
result.second = y;
break;
}
// Randomly select a direction
if (agent_start_dir_ == -1) {
std::uniform_int_distribution<> dir_dist(0, 3);
agent_dir_ = dir_dist(*gen_ref_);
}
return result;
}

void MiniGridEnv::GenImage(const Array& obs) {
Expand Down Expand Up @@ -213,8 +231,44 @@ void MiniGridEnv::GenImage(const Array& obs) {
std::fill(row.begin(), row.end(), 0);
}
if (!see_through_walls_) {
// TODO(siping): Process_vis
vis_mask[agent_pos_y][agent_pos_x] = true;
for (int j = agent_view_size_ - 1; j >= 0; --j) {
// left -> right
for (int i = 0; i <= agent_view_size_ - 2; ++i) {
if (!vis_mask[j][i]) {
continue;
}
if (!agent_view_grid[j][i].CanSeeBehind()) {
continue;
}
vis_mask[j][i + 1] = true;
if (j > 0) {
vis_mask[j - 1][i + 1] = true;
vis_mask[j - 1][i] = true;
}
}
// right -> left
for (int i = agent_view_size_ - 1; i >= 1; --i) {
if (!vis_mask[j][i]) {
continue;
}
if (!agent_view_grid[j][i].CanSeeBehind()) {
continue;
}
vis_mask[j][i - 1] = true;
if (j > 0) {
vis_mask[j - 1][i - 1] = true;
vis_mask[j - 1][i] = true;
}
}
}
for (int j = 0; j < agent_view_size_; ++j) {
for (int i = 0; i < agent_view_size_; ++i) {
if (!vis_mask[j][i]) {
agent_view_grid[j][i] = WorldObj(kEmpty);
}
}
}
} else {
for (auto& row : vis_mask) {
std::fill(row.begin(), row.end(), 1);
Expand Down
2 changes: 2 additions & 0 deletions envpool/minigrid/impl/minigrid_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class MiniGridEnv {
float MiniGridStep(Act act);
void PlaceAgent(int start_x = 0, int start_y = 0, int end_x = -1,
int end_y = -1);
std::pair<int, int> PlaceObject(int start_x, int start_y, int end_x,
int end_y);
void GenImage(const Array& obs);
virtual void GenGrid() {}
};
Expand Down
2 changes: 1 addition & 1 deletion envpool/minigrid/impl/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class WorldObj {
[[nodiscard]] bool GetDoorOpen() const { return door_open_; }
void SetDoorOpen(bool flag) { door_open_ = flag; }
[[nodiscard]] bool GetDoorLocked() const { return door_locked_; }
void SetDoorLocker(bool flag) { door_locked_ = flag; }
void SetDoorLocked(bool flag) { door_locked_ = flag; }
Type GetType() { return type_; }
Color GetColor() { return color_; }
int GetState() {
Expand Down
9 changes: 8 additions & 1 deletion envpool/minigrid/minigrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@
// limitations under the License.

#include "envpool/core/py_envpool.h"
#include "envpool/minigrid/doorkey.h"
#include "envpool/minigrid/empty.h"

using EmptyEnvSpec = PyEnvSpec<minigrid::EmptyEnvSpec>;
using EmptyEnvPool = PyEnvPool<minigrid::EmptyEnvPool>;

PYBIND11_MODULE(minigrid_envpool, m) { REGISTER(m, EmptyEnvSpec, EmptyEnvPool) }
using DoorKeyEnvSpec = PyEnvSpec<minigrid::DoorKeyEnvSpec>;
using DoorKeyEnvPool = PyEnvPool<minigrid::DoorKeyEnvPool>;

PYBIND11_MODULE(minigrid_envpool, m) {
REGISTER(m, EmptyEnvSpec, EmptyEnvPool)
REGISTER(m, DoorKeyEnvSpec, DoorKeyEnvPool)
}
11 changes: 11 additions & 0 deletions envpool/minigrid/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,14 @@
max_episode_steps=1024,
size=16,
)

# DoorKey task on a 5x5 grid, with episodes capped at 250 steps.
register(
task_id="MiniGrid-DoorKey-5x5-v0",
import_path="envpool.minigrid",
spec_cls="DoorKeyEnvSpec",
dm_cls="DoorKeyDMEnvPool",
gym_cls="DoorKeyGymEnvPool",
gymnasium_cls="DoorKeyGymnasiumEnvPool",
max_episode_steps=250,
size=5,
)