Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Add menu_strs observations.
Browse files Browse the repository at this point in the history
This is an observation that renders windows that are menus that pop up.
  • Loading branch information
cdmatters committed Jul 15, 2021
1 parent 7832d01 commit 4f76029
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 7 deletions.
5 changes: 4 additions & 1 deletion include/nleobs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#define NLE_BLSTATS_SIZE 25
#define NLE_PROGRAM_STATE_SIZE 6
#define NLE_INTERNAL_SIZE 9
#define NLE_MISC_SIZE 3
#define NLE_MISC_SIZE 4
#define NLE_MENU_SIZE 24
#define NLE_MENU_STR_LENGTH 80
#define NLE_INVENTORY_SIZE 55
#define NLE_INVENTORY_STR_LENGTH 80
#define NLE_SCREEN_DESCRIPTION_LENGTH 80
Expand Down Expand Up @@ -39,6 +41,7 @@ typedef struct nle_observation {
signed char *tty_colors; /* Size NLE_TERM_LI * NLE_TERM_CO */
unsigned char *tty_cursor; /* Size 2 */
int *misc; /* Size NLE_MISC_SIZE */
unsigned char *menu_strs; /* Size NLE_MENU_SIZE * NLE_MENU_STR_LENGTH */
} nle_obs;

typedef struct {
Expand Down
6 changes: 6 additions & 0 deletions nle/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@
**nethack.OBSERVATION_DESC["misc"],
),
),
(
"menu_strs",
gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["menu_strs"]),
),
)


Expand Down Expand Up @@ -213,6 +217,8 @@ def __init__(
"tty_chars",
"tty_colors",
"tty_cursor",
"misc",
"menu_strs",
),
actions=None,
options=None,
Expand Down
1 change: 1 addition & 0 deletions nle/env/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def __init__(
"tty_colors",
"tty_cursor",
"misc",
"menu_strs",
),
no_progress_timeout: int = 10_000,
**kwargs,
Expand Down
5 changes: 5 additions & 0 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
PROGRAM_STATE_SHAPE = (_pynethack.nethack.NLE_PROGRAM_STATE_SIZE,)
INTERNAL_SHAPE = (_pynethack.nethack.NLE_INTERNAL_SIZE,)
MISC_SHAPE = (_pynethack.nethack.NLE_MISC_SIZE,)
MENU_STRS_SHAPE = (
_pynethack.nethack.NLE_MENU_SIZE,
_pynethack.nethack.NLE_MENU_STR_LENGTH,
)
INV_SIZE = (_pynethack.nethack.NLE_INVENTORY_SIZE,)
INV_STRS_SHAPE = (
_pynethack.nethack.NLE_INVENTORY_SIZE,
Expand Down Expand Up @@ -45,6 +49,7 @@
"tty_colors": dict(shape=TERMINAL_SHAPE, dtype=np.int8),
"tty_cursor": dict(shape=(2,), dtype=np.uint8),
"misc": dict(shape=MISC_SHAPE, dtype=np.int32),
"menu_strs": dict(shape=MENU_STRS_SHAPE, dtype=np.uint8),
}


Expand Down
6 changes: 3 additions & 3 deletions nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_misc_yn_question(self, game):
np.testing.assert_array_equal(misc, internal[1:4])

game.step(nethack.M("p")) # pray
np.testing.assert_array_equal(misc, np.array([1, 0, 0]))
np.testing.assert_array_equal(misc, np.array([1, 0, 0, 0]))
np.testing.assert_array_equal(misc, internal[1:4])

game.step(ord("n"))
Expand All @@ -490,7 +490,7 @@ def test_misc_getline(self, game):

game.step(nethack.M("n")) # name ..
game.step(ord("a")) # ... the current level
np.testing.assert_array_equal(misc, np.array([0, 1, 0]))
np.testing.assert_array_equal(misc, np.array([0, 1, 0, 0]))
np.testing.assert_array_equal(misc, internal[1:4])

for let in "Gehennom":
Expand All @@ -512,7 +512,7 @@ def test_misc_wait_for_space(self, game):
np.testing.assert_array_equal(misc, internal[1:4])

game.step(ord("i"))
np.testing.assert_array_equal(misc, np.array([0, 0, 1]))
np.testing.assert_array_equal(misc, np.array([0, 0, 1, 0]))
np.testing.assert_array_equal(misc, internal[1:4])

game.step(ord(" "))
Expand Down
13 changes: 10 additions & 3 deletions win/rl/pynethack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ class Nethack
py::object inv_glyphs, py::object inv_letters,
py::object inv_oclasses, py::object inv_strs,
py::object screen_descriptions, py::object tty_chars,
py::object tty_colors, py::object tty_cursor, py::object misc)
py::object tty_colors, py::object tty_cursor, py::object misc,
py::object menu_strs)
{
std::vector<ssize_t> dungeon{ ROWNO, COLNO - 1 };
obs_.glyphs = checked_conversion<int16_t>(glyphs, dungeon);
Expand Down Expand Up @@ -160,6 +161,8 @@ class Nethack
tty_colors, { NLE_TERM_LI, NLE_TERM_CO });
obs_.tty_cursor = checked_conversion<uint8_t>(tty_cursor, { 2 });
obs_.misc = checked_conversion<int32_t>(misc, { NLE_MISC_SIZE });
obs_.menu_strs = checked_conversion<uint8_t>(
menu_strs, { NLE_MENU_SIZE, NLE_MENU_STR_LENGTH });

py_buffers_ = { std::move(glyphs),
std::move(chars),
Expand All @@ -177,7 +180,8 @@ class Nethack
std::move(tty_chars),
std::move(tty_colors),
std::move(tty_cursor),
std::move(misc) };
std::move(misc),
std::move(menu_strs) };
}

void
Expand Down Expand Up @@ -297,7 +301,8 @@ PYBIND11_MODULE(_pynethack, m)
py::arg("screen_descriptions") = py::none(),
py::arg("tty_chars") = py::none(),
py::arg("tty_colors") = py::none(),
py::arg("tty_cursor") = py::none(), py::arg("misc") = py::none())
py::arg("tty_cursor") = py::none(), py::arg("misc") = py::none(),
py::arg("menu_strs") = py::none())
.def("close", &Nethack::close)
.def("set_initial_seeds", &Nethack::set_initial_seeds)
.def("set_seeds", &Nethack::set_seeds)
Expand All @@ -314,6 +319,8 @@ PYBIND11_MODULE(_pynethack, m)
mn.attr("NLE_PROGRAM_STATE_SIZE") = py::int_(NLE_PROGRAM_STATE_SIZE);
mn.attr("NLE_INTERNAL_SIZE") = py::int_(NLE_INTERNAL_SIZE);
mn.attr("NLE_MISC_SIZE") = py::int_(NLE_MISC_SIZE);
mn.attr("NLE_MENU_SIZE") = py::int_(NLE_MENU_SIZE);
mn.attr("NLE_MENU_STR_LENGTH") = py::int_(NLE_MENU_STR_LENGTH);
mn.attr("NLE_INVENTORY_SIZE") = py::int_(NLE_INVENTORY_SIZE);
mn.attr("NLE_INVENTORY_STR_LENGTH") = py::int_(NLE_INVENTORY_STR_LENGTH);
mn.attr("NLE_SCREEN_DESCRIPTION_LENGTH") =
Expand Down
52 changes: 52 additions & 0 deletions win/rl/winrl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ NetHackRL::fill_obs(nle_obs *obs)
obs->misc[0] = in_yn_function;
obs->misc[1] = in_getlin;
obs->misc[2] = xwaitingforspace;
obs->misc[3] = false;
for (int i = 0; i < windows_.size(); ++i) {
// We have a Non-Inventory Menu Window
if (i != WIN_INVEN && windows_[i].get()->type == NHW_MENU) {
obs->misc[3] = true;
}
}
}

if ((!program_state.something_worth_saving && !program_state.in_moveloop)
Expand Down Expand Up @@ -353,6 +360,50 @@ NetHackRL::fill_obs(nle_obs *obs)
std::memset(obs->message, 0, NLE_MESSAGE_SIZE);
}
}

if (obs->menu_strs) {
bool found_menu = false;
for (int i = 0; i < windows_.size(); ++i) {
if (i != WIN_INVEN && windows_[i].get()->type == NHW_MENU) {
// We have a new menu window to be rendered on screen!
// The menu can either have 'rl_menu_items' in init, or
// simple text. In this case we expect only one of the two
// to be available, and this one goes into menu_strs
found_menu = true;
rl_window *win = windows_[i].get();
assert(!win->menu_items.empty() ^ !win->strings.empty());

int rows = max(win->menu_items.size(), win->strings.size());
int blank_rows = max(NLE_MENU_SIZE - rows, 0);
bool do_full_menu = !win->menu_items.empty();

for (int i = 0; i < NLE_MENU_SIZE; ++i) {
if (i >= rows)
break;
if (do_full_menu) {
// we would need to send over the mappings etc in
// different arrays
std::strncpy(
(char *) &obs->menu_strs[i * NLE_MENU_STR_LENGTH],
win->menu_items[i].str.c_str(),
NLE_MENU_STR_LENGTH);
} else {
// here we simply would put defaults or indicators
std::strncpy(
(char *) &obs->menu_strs[i * NLE_MENU_STR_LENGTH],
win->strings[i].c_str(), NLE_MENU_STR_LENGTH);
}
}
std::memset(&obs->menu_strs[rows * NLE_MENU_STR_LENGTH], 0,
blank_rows * NLE_MENU_STR_LENGTH);
}
}
if (!found_menu) {
std::memset(&obs->menu_strs[0], 0,
NLE_MENU_SIZE * NLE_MENU_STR_LENGTH);
}
}

if (obs->blstats) {
if (!u.dz) {
/* Tricky hack: On "You descend the stairs.--More--" we are
Expand Down Expand Up @@ -670,6 +721,7 @@ NetHackRL::destroy_nhwindow_method(winid wid)
{
DEBUG_API("rl_destroy_nhwindow(wid=" << wid << ")" << std::endl);
windows_[wid].reset(nullptr);
windows_.resize(wid);
tty_destroy_nhwindow(wid);
}

Expand Down

0 comments on commit 4f76029

Please sign in to comment.