From 4f7602920ac6e5a6e41f1b71ac8b872ddcdd3b44 Mon Sep 17 00:00:00 2001 From: Eric Hambro Date: Mon, 12 Jul 2021 09:58:06 -0700 Subject: [PATCH] Add menu_strs observations. This observation renders the contents of pop-up menu windows (NHW_MENU windows other than the inventory) as rows of text. The misc observation also gains a fourth entry that is set while such a window exists. --- include/nleobs.h | 5 +++- nle/env/base.py | 6 +++++ nle/env/tasks.py | 1 + nle/nethack/nethack.py | 5 ++++ nle/tests/test_nethack.py | 6 ++--- win/rl/pynethack.cc | 13 +++++++--- win/rl/winrl.cc | 52 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 81 insertions(+), 7 deletions(-) diff --git a/include/nleobs.h b/include/nleobs.h index ffe72621d..1b99da3a9 100644 --- a/include/nleobs.h +++ b/include/nleobs.h @@ -6,7 +6,9 @@ #define NLE_BLSTATS_SIZE 25 #define NLE_PROGRAM_STATE_SIZE 6 #define NLE_INTERNAL_SIZE 9 -#define NLE_MISC_SIZE 3 +#define NLE_MISC_SIZE 4 +#define NLE_MENU_SIZE 24 +#define NLE_MENU_STR_LENGTH 80 #define NLE_INVENTORY_SIZE 55 #define NLE_INVENTORY_STR_LENGTH 80 #define NLE_SCREEN_DESCRIPTION_LENGTH 80 @@ -39,6 +41,7 @@ typedef struct nle_observation { signed char *tty_colors; /* Size NLE_TERM_LI * NLE_TERM_CO */ unsigned char *tty_cursor; /* Size 2 */ int *misc; /* Size NLE_MISC_SIZE */ + unsigned char *menu_strs; /* Size NLE_MENU_SIZE * NLE_MENU_STR_LENGTH */ } nle_obs; typedef struct { diff --git a/nle/env/base.py b/nle/env/base.py index b17ab7a97..83790b61a 100644 --- a/nle/env/base.py +++ b/nle/env/base.py @@ -138,6 +138,10 @@ **nethack.OBSERVATION_DESC["misc"], ), ), + ( + "menu_strs", + gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["menu_strs"]), + ), ) @@ -213,6 +217,8 @@ def __init__( "tty_chars", "tty_colors", "tty_cursor", + "misc", + "menu_strs", ), actions=None, options=None, diff --git a/nle/env/tasks.py b/nle/env/tasks.py index 68cea76b3..25727613e 100644 --- a/nle/env/tasks.py +++ b/nle/env/tasks.py @@ -321,6 +321,7 @@ def __init__( "tty_colors", "tty_cursor", "misc", + "menu_strs", ), no_progress_timeout: int = 10_000, **kwargs, diff --git a/nle/nethack/nethack.py
b/nle/nethack/nethack.py index fe4892872..037a951b4 100644 --- a/nle/nethack/nethack.py +++ b/nle/nethack/nethack.py @@ -17,6 +17,10 @@ PROGRAM_STATE_SHAPE = (_pynethack.nethack.NLE_PROGRAM_STATE_SIZE,) INTERNAL_SHAPE = (_pynethack.nethack.NLE_INTERNAL_SIZE,) MISC_SHAPE = (_pynethack.nethack.NLE_MISC_SIZE,) +MENU_STRS_SHAPE = ( + _pynethack.nethack.NLE_MENU_SIZE, + _pynethack.nethack.NLE_MENU_STR_LENGTH, +) INV_SIZE = (_pynethack.nethack.NLE_INVENTORY_SIZE,) INV_STRS_SHAPE = ( _pynethack.nethack.NLE_INVENTORY_SIZE, @@ -45,6 +49,7 @@ "tty_colors": dict(shape=TERMINAL_SHAPE, dtype=np.int8), "tty_cursor": dict(shape=(2,), dtype=np.uint8), "misc": dict(shape=MISC_SHAPE, dtype=np.int32), + "menu_strs": dict(shape=MENU_STRS_SHAPE, dtype=np.uint8), } diff --git a/nle/tests/test_nethack.py b/nle/tests/test_nethack.py index 3e3e94cd3..684d71e05 100644 --- a/nle/tests/test_nethack.py +++ b/nle/tests/test_nethack.py @@ -472,7 +472,7 @@ def test_misc_yn_question(self, game): np.testing.assert_array_equal(misc, internal[1:4]) game.step(nethack.M("p")) # pray - np.testing.assert_array_equal(misc, np.array([1, 0, 0])) + np.testing.assert_array_equal(misc, np.array([1, 0, 0, 0])) np.testing.assert_array_equal(misc, internal[1:4]) game.step(ord("n")) @@ -490,7 +490,7 @@ def test_misc_getline(self, game): game.step(nethack.M("n")) # name .. game.step(ord("a")) # ... 
the current level - np.testing.assert_array_equal(misc, np.array([0, 1, 0])) + np.testing.assert_array_equal(misc, np.array([0, 1, 0, 0])) np.testing.assert_array_equal(misc, internal[1:4]) for let in "Gehennom": @@ -512,7 +512,7 @@ def test_misc_wait_for_space(self, game): np.testing.assert_array_equal(misc, internal[1:4]) game.step(ord("i")) - np.testing.assert_array_equal(misc, np.array([0, 0, 1])) + np.testing.assert_array_equal(misc, np.array([0, 0, 1, 0])) np.testing.assert_array_equal(misc, internal[1:4]) game.step(ord(" ")) diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc index 732a5abea..853823f29 100644 --- a/win/rl/pynethack.cc +++ b/win/rl/pynethack.cc @@ -129,7 +129,8 @@ class Nethack py::object inv_glyphs, py::object inv_letters, py::object inv_oclasses, py::object inv_strs, py::object screen_descriptions, py::object tty_chars, - py::object tty_colors, py::object tty_cursor, py::object misc) + py::object tty_colors, py::object tty_cursor, py::object misc, + py::object menu_strs) { std::vector dungeon{ ROWNO, COLNO - 1 }; obs_.glyphs = checked_conversion(glyphs, dungeon); @@ -160,6 +161,8 @@ class Nethack tty_colors, { NLE_TERM_LI, NLE_TERM_CO }); obs_.tty_cursor = checked_conversion(tty_cursor, { 2 }); obs_.misc = checked_conversion(misc, { NLE_MISC_SIZE }); + obs_.menu_strs = checked_conversion( + menu_strs, { NLE_MENU_SIZE, NLE_MENU_STR_LENGTH }); py_buffers_ = { std::move(glyphs), std::move(chars), @@ -177,7 +180,8 @@ class Nethack std::move(tty_chars), std::move(tty_colors), std::move(tty_cursor), - std::move(misc) }; + std::move(misc), + std::move(menu_strs) }; } void @@ -297,7 +301,8 @@ PYBIND11_MODULE(_pynethack, m) py::arg("screen_descriptions") = py::none(), py::arg("tty_chars") = py::none(), py::arg("tty_colors") = py::none(), - py::arg("tty_cursor") = py::none(), py::arg("misc") = py::none()) + py::arg("tty_cursor") = py::none(), py::arg("misc") = py::none(), + py::arg("menu_strs") = py::none()) .def("close", &Nethack::close) 
.def("set_initial_seeds", &Nethack::set_initial_seeds) .def("set_seeds", &Nethack::set_seeds) @@ -314,6 +319,8 @@ PYBIND11_MODULE(_pynethack, m) mn.attr("NLE_PROGRAM_STATE_SIZE") = py::int_(NLE_PROGRAM_STATE_SIZE); mn.attr("NLE_INTERNAL_SIZE") = py::int_(NLE_INTERNAL_SIZE); mn.attr("NLE_MISC_SIZE") = py::int_(NLE_MISC_SIZE); + mn.attr("NLE_MENU_SIZE") = py::int_(NLE_MENU_SIZE); + mn.attr("NLE_MENU_STR_LENGTH") = py::int_(NLE_MENU_STR_LENGTH); mn.attr("NLE_INVENTORY_SIZE") = py::int_(NLE_INVENTORY_SIZE); mn.attr("NLE_INVENTORY_STR_LENGTH") = py::int_(NLE_INVENTORY_STR_LENGTH); mn.attr("NLE_SCREEN_DESCRIPTION_LENGTH") = diff --git a/win/rl/winrl.cc b/win/rl/winrl.cc index 836b63a7c..f1136eb76 100644 --- a/win/rl/winrl.cc +++ b/win/rl/winrl.cc @@ -290,6 +290,13 @@ NetHackRL::fill_obs(nle_obs *obs) obs->misc[0] = in_yn_function; obs->misc[1] = in_getlin; obs->misc[2] = xwaitingforspace; + obs->misc[3] = false; + for (int i = 0; i < windows_.size(); ++i) { + // Non-inventory menu window; destroyed windows leave null slots + if (i != WIN_INVEN && windows_[i] && windows_[i]->type == NHW_MENU) { + obs->misc[3] = true; + } + } } if ((!program_state.something_worth_saving && !program_state.in_moveloop) @@ -353,6 +360,50 @@ NetHackRL::fill_obs(nle_obs *obs) std::memset(obs->message, 0, NLE_MESSAGE_SIZE); } } + + if (obs->menu_strs) { + bool found_menu = false; + for (int i = 0; i < windows_.size(); ++i) { + if (i != WIN_INVEN && windows_[i] && windows_[i]->type == NHW_MENU) { + // A non-inventory menu window exists; render it into menu_strs. + // Its text is held either in 'menu_items' or in 'strings'. + // Exactly one of the two is expected
to be non-empty (see the + // assert below); that one is copied row by row. + found_menu = true; + rl_window *win = windows_[i].get(); + assert(!win->menu_items.empty() ^ !win->strings.empty()); + + int rows = max(win->menu_items.size(), win->strings.size()); + int blank_rows = max(NLE_MENU_SIZE - rows, 0); + bool do_full_menu = !win->menu_items.empty(); + + for (int row = 0; row < NLE_MENU_SIZE; ++row) { + if (row >= rows) + break; + if (do_full_menu) { + // Only the item text is copied; the mappings etc would + // need to be sent over in different arrays + std::strncpy( + (char *) &obs->menu_strs[row * NLE_MENU_STR_LENGTH], + win->menu_items[row].str.c_str(), + NLE_MENU_STR_LENGTH); + } else { + // Text-only menu: each stored string is one row + std::strncpy( + (char *) &obs->menu_strs[row * NLE_MENU_STR_LENGTH], + win->strings[row].c_str(), NLE_MENU_STR_LENGTH); + } + } + std::memset(&obs->menu_strs[rows * NLE_MENU_STR_LENGTH], 0, + blank_rows * NLE_MENU_STR_LENGTH); + } + } + if (!found_menu) { + std::memset(&obs->menu_strs[0], 0, + NLE_MENU_SIZE * NLE_MENU_STR_LENGTH); + } + } + if (obs->blstats) { if (!u.dz) { /* Tricky hack: On "You descend the stairs.--More--" we are @@ -670,6 +721,7 @@ NetHackRL::destroy_nhwindow_method(winid wid) { DEBUG_API("rl_destroy_nhwindow(wid=" << wid << ")" << std::endl); windows_[wid].reset(nullptr); + // Keep the null slot: resizing to wid would also drop windows above wid. tty_destroy_nhwindow(wid); }