diff --git a/envpool/minigrid/BUILD b/envpool/minigrid/BUILD index 1ae60e13..d7d2795f 100644 --- a/envpool/minigrid/BUILD +++ b/envpool/minigrid/BUILD @@ -20,11 +20,14 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "minigrid_env", srcs = [ + "impl/minigrid_doorkey_env.cc", "impl/minigrid_empty_env.cc", "impl/minigrid_env.cc", ], hdrs = [ + "doorkey.h", "empty.h", + "impl/minigrid_doorkey_env.h", "impl/minigrid_empty_env.h", "impl/minigrid_env.h", "impl/utils.h", diff --git a/envpool/minigrid/doorkey.h b/envpool/minigrid/doorkey.h new file mode 100644 index 00000000..dd57730d --- /dev/null +++ b/envpool/minigrid/doorkey.h @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ENVPOOL_MINIGRID_DOORKEY_H_ +#define ENVPOOL_MINIGRID_DOORKEY_H_ + +#include + +#include "envpool/core/async_envpool.h" +#include "envpool/core/env.h" +#include "envpool/minigrid/impl/minigrid_doorkey_env.h" +#include "envpool/minigrid/impl/minigrid_env.h" + +namespace minigrid { + +class DoorKeyEnvFns { + public: + static decltype(auto) DefaultConfig() { + return MakeDict("size"_.Bind(8), + "agent_start_pos"_.Bind(std::pair(-1, -1)), + "agent_start_dir"_.Bind(-1), "agent_view_size"_.Bind(7)); + } + template + static decltype(auto) StateSpec(const Config& conf) { + int agent_view_size = conf["agent_view_size"_]; + int size = conf["size"_]; + return MakeDict("obs:direction"_.Bind(Spec({-1}, {0, 3})), + "obs:image"_.Bind(Spec( + {agent_view_size, agent_view_size, 3}, {0, 255})), + "info:agent_pos"_.Bind(Spec({2}, {0, size}))); + } + template + static decltype(auto) ActionSpec(const Config& conf) { + return MakeDict("action"_.Bind(Spec({-1}, {0, 6}))); + } +}; + +using DoorKeyEnvSpec = EnvSpec; +using FrameSpec = Spec; + +class DoorKeyEnv : public Env, public MiniGridDoorKeyEnv { + public: + DoorKeyEnv(const Spec& spec, int env_id) + : Env(spec, env_id), + MiniGridDoorKeyEnv( + spec.config["size"_], spec.config["agent_start_pos"_], + spec.config["agent_start_dir"_], spec.config["max_episode_steps"_], + spec.config["agent_view_size"_]) { + gen_ref_ = &gen_; + } + + bool IsDone() override { return done_; } + + void Reset() override { + MiniGridReset(); + WriteState(0.0); + } + + void Step(const Action& action) override { + int act = action["action"_]; + WriteState(MiniGridStep(static_cast(act))); + } + + private: + void WriteState(float reward) { + State state = Allocate(); + GenImage(state["obs:image"_]); + state["obs:direction"_] = agent_dir_; + state["reward"_] = reward; + state["info:agent_pos"_](0) = agent_pos_.first; + state["info:agent_pos"_](1) = agent_pos_.second; + } +}; + +using DoorKeyEnvPool = AsyncEnvPool; + +} // namespace minigrid + 
+#endif // ENVPOOL_MINIGRID_DOORKEY_H_ diff --git a/envpool/minigrid/impl/minigrid_doorkey_env.cc b/envpool/minigrid/impl/minigrid_doorkey_env.cc new file mode 100644 index 00000000..0f473365 --- /dev/null +++ b/envpool/minigrid/impl/minigrid_doorkey_env.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/minigrid/impl/minigrid_doorkey_env.h" + +#include +#include + +namespace minigrid { + +MiniGridDoorKeyEnv::MiniGridDoorKeyEnv(int size, + std::pair agent_start_pos, + int agent_start_dir, int max_steps, + int agent_view_size) { + width_ = size; + height_ = size; + agent_start_pos_ = agent_start_pos; + agent_start_dir_ = agent_start_dir; + see_through_walls_ = false; + max_steps_ = max_steps; + agent_view_size_ = agent_view_size; +} + +void MiniGridDoorKeyEnv::GenGrid() { + grid_.clear(); + for (int i = 0; i < height_; ++i) { + std::vector temp_vec(width_); + for (int j = 0; j < width_; ++j) { + temp_vec[j] = WorldObj(kEmpty); + } + grid_.emplace_back(temp_vec); + } + // generate the surrounding walls + for (int i = 0; i < width_; ++i) { + grid_[0][i] = WorldObj(kWall, kGrey); + grid_[height_ - 1][i] = WorldObj(kWall, kGrey); + } + for (int i = 0; i < height_; ++i) { + grid_[i][0] = WorldObj(kWall, kGrey); + grid_[i][width_ - 1] = WorldObj(kWall, kGrey); + } + // place a goal square in the bottom-right corner + grid_[height_ - 2][width_ - 2] = WorldObj(kGoal, kGreen); + // 
generate a vertical splitting wall + std::uniform_int_distribution<> x_dist(2, width_ - 3); + int x = x_dist(*gen_ref_); + for (int y = 0; y < height_; ++y) { + grid_[y][x] = WorldObj(kWall, kGrey); + } + // place the agent at a random position and orientation + // on the left side of the splitting wall + PlaceAgent(1, 1, x - 1, height_ - 2); + // place a door in the wall + std::uniform_int_distribution<> y_dist(1, height_ - 3); + int door_idx = y_dist(*gen_ref_); + grid_[door_idx][x] = WorldObj(kDoor, kYellow); + grid_[door_idx][x].SetDoorLocked(true); + grid_[door_idx][x].SetDoorOpen(false); + // place a yellow key on the left side + auto pos = PlaceObject(1, 1, x - 1, height_ - 2); + grid_[pos.second][pos.first] = WorldObj(kKey, kYellow); +} + +} // namespace minigrid diff --git a/envpool/minigrid/impl/minigrid_doorkey_env.h b/envpool/minigrid/impl/minigrid_doorkey_env.h new file mode 100644 index 00000000..4f5b091b --- /dev/null +++ b/envpool/minigrid/impl/minigrid_doorkey_env.h @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_ +#define ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_ + +#include + +#include "envpool/minigrid/impl/minigrid_env.h" + +namespace minigrid { + +class MiniGridDoorKeyEnv : public MiniGridEnv { + public: + MiniGridDoorKeyEnv(int size, std::pair agent_start_pos, + int agent_start_dir, int max_steps, int agent_view_size); + void GenGrid() override; +}; + +} // namespace minigrid + +#endif // ENVPOOL_MINIGRID_IMPL_MINIGRID_DOORKEY_ENV_H_ diff --git a/envpool/minigrid/impl/minigrid_env.cc b/envpool/minigrid/impl/minigrid_env.cc index c2ac26e6..152af5f4 100644 --- a/envpool/minigrid/impl/minigrid_env.cc +++ b/envpool/minigrid/impl/minigrid_env.cc @@ -136,23 +136,41 @@ void MiniGridEnv::PlaceAgent(int start_x, int start_y, int end_x, int end_y) { end_x = (end_x == -1) ? width_ - 1 : end_x; end_y = (end_y == -1) ? height_ - 1 : end_y; CHECK(start_x <= end_x && start_y <= end_y); + agent_pos_.first = -1; + agent_pos_.second = -1; + auto pos = PlaceObject(start_x, start_y, end_x, end_y); + agent_pos_.first = pos.first; + agent_pos_.second = pos.second; + // Randomly select a direction + if (agent_start_dir_ == -1) { + std::uniform_int_distribution<> dir_dist(0, 3); + agent_dir_ = dir_dist(*gen_ref_); + } +} + +// place an object where x-index in [start_x, end_x] and y-index in [start_y, +// end_y]; returns the chosen position (x, y) +std::pair MiniGridEnv::PlaceObject(int start_x, int start_y, + int end_x, int end_y) { + std::pair result; std::uniform_int_distribution<> x_dist(start_x, end_x); std::uniform_int_distribution<> y_dist(start_y, end_y); while (true) { int x = x_dist(*gen_ref_); int y = y_dist(*gen_ref_); + // don't place the object on top of another object if (grid_[y][x].GetType() != kEmpty) { continue; } - agent_pos_.first = x; - agent_pos_.second = y; + // don't place the object where the agent is + if (agent_pos_.first == x && agent_pos_.second == y) { + continue; + } + result.first = x; + 
result.second = y; break; } - // Randomly select a direction - if (agent_start_dir_ == -1) { - std::uniform_int_distribution<> dir_dist(0, 3); - agent_dir_ = dir_dist(*gen_ref_); - } + return result; } void MiniGridEnv::GenImage(const Array& obs) { @@ -213,8 +231,44 @@ void MiniGridEnv::GenImage(const Array& obs) { std::fill(row.begin(), row.end(), 0); } if (!see_through_walls_) { - // TODO(siping): Process_vis vis_mask[agent_pos_y][agent_pos_x] = true; + for (int j = agent_view_size_ - 1; j >= 0; --j) { + // left -> right + for (int i = 0; i <= agent_view_size_ - 2; ++i) { + if (!vis_mask[j][i]) { + continue; + } + if (!agent_view_grid[j][i].CanSeeBehind()) { + continue; + } + vis_mask[j][i + 1] = true; + if (j > 0) { + vis_mask[j - 1][i + 1] = true; + vis_mask[j - 1][i] = true; + } + } + // right -> left + for (int i = agent_view_size_ - 1; i >= 1; --i) { + if (!vis_mask[j][i]) { + continue; + } + if (!agent_view_grid[j][i].CanSeeBehind()) { + continue; + } + vis_mask[j][i - 1] = true; + if (j > 0) { + vis_mask[j - 1][i - 1] = true; + vis_mask[j - 1][i] = true; + } + } + } + for (int j = 0; j < agent_view_size_; ++j) { + for (int i = 0; i < agent_view_size_; ++i) { + if (!vis_mask[j][i]) { + agent_view_grid[j][i] = WorldObj(kEmpty); + } + } + } } else { for (auto& row : vis_mask) { std::fill(row.begin(), row.end(), 1); diff --git a/envpool/minigrid/impl/minigrid_env.h b/envpool/minigrid/impl/minigrid_env.h index fadf8de0..e7dd6ff7 100644 --- a/envpool/minigrid/impl/minigrid_env.h +++ b/envpool/minigrid/impl/minigrid_env.h @@ -49,6 +49,8 @@ class MiniGridEnv { float MiniGridStep(Act act); void PlaceAgent(int start_x = 0, int start_y = 0, int end_x = -1, int end_y = -1); + std::pair PlaceObject(int start_x, int start_y, int end_x, + int end_y); void GenImage(const Array& obs); virtual void GenGrid() {} }; diff --git a/envpool/minigrid/impl/utils.h b/envpool/minigrid/impl/utils.h index 328146fa..5ff25c79 100644 --- a/envpool/minigrid/impl/utils.h +++ 
b/envpool/minigrid/impl/utils.h @@ -125,7 +125,7 @@ class WorldObj { [[nodiscard]] bool GetDoorOpen() const { return door_open_; } void SetDoorOpen(bool flag) { door_open_ = flag; } [[nodiscard]] bool GetDoorLocked() const { return door_locked_; } - void SetDoorLocker(bool flag) { door_locked_ = flag; } + void SetDoorLocked(bool flag) { door_locked_ = flag; } Type GetType() { return type_; } Color GetColor() { return color_; } int GetState() { diff --git a/envpool/minigrid/minigrid.cc b/envpool/minigrid/minigrid.cc index d0af3d07..e035ba93 100644 --- a/envpool/minigrid/minigrid.cc +++ b/envpool/minigrid/minigrid.cc @@ -13,9 +13,16 @@ // limitations under the License. #include "envpool/core/py_envpool.h" +#include "envpool/minigrid/doorkey.h" #include "envpool/minigrid/empty.h" using EmptyEnvSpec = PyEnvSpec; using EmptyEnvPool = PyEnvPool; -PYBIND11_MODULE(minigrid_envpool, m) { REGISTER(m, EmptyEnvSpec, EmptyEnvPool) } +using DoorKeyEnvSpec = PyEnvSpec; +using DoorKeyEnvPool = PyEnvPool; + +PYBIND11_MODULE(minigrid_envpool, m) { + REGISTER(m, EmptyEnvSpec, EmptyEnvPool) + REGISTER(m, DoorKeyEnvSpec, DoorKeyEnvPool) +} diff --git a/envpool/minigrid/registration.py b/envpool/minigrid/registration.py index 03d586f8..43540b0e 100644 --- a/envpool/minigrid/registration.py +++ b/envpool/minigrid/registration.py @@ -84,3 +84,14 @@ max_episode_steps=1024, size=16, ) + +register( + task_id="MiniGrid-DoorKey-5x5-v0", + import_path="envpool.minigrid", + spec_cls="DoorKeyEnvSpec", + dm_cls="DoorKeyDMEnvPool", + gym_cls="DoorKeyGymEnvPool", + gymnasium_cls="DoorKeyGymnasiumEnvPool", + max_episode_steps=250, + size=5, +)