Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

[WIP] Libnethack shared #140

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ endif()

message(STATUS "Building nle backend version: ${NLE_VERSION}")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# We use this to decide where the root of the nle/ package is. Normally it
Expand Down Expand Up @@ -99,10 +100,12 @@ target_link_directories(nethack PUBLIC /usr/local/lib)
target_link_libraries(nethack PUBLIC m fcontext bz2)

# dlopen wrapper library
add_library(nethackdl STATIC "sys/unix/nledl.c")
add_library(nethackdl STATIC "sys/unix/nledl.c" "sys/unix/nleshared.cc")
target_include_directories(nethackdl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(nethackdl PUBLIC dl)

#set_target_properties(nethackdl PROPERTIES CMAKE_CXX_STANDARD 17)

# rlmain C++ (test) binary
add_executable(rlmain "sys/unix/rlmain.cc")
set_target_properties(rlmain PROPERTIES CXX_STANDARD 11)
Expand Down
14 changes: 5 additions & 9 deletions include/nledl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,9 @@
#include "nleobs.h"

/* TODO: Don't call this nle_ctx_t as well. */
typedef struct nledl_ctx {
char dlpath[1024];
void *dlhandle;
void *nle_ctx;
void *(*step)(void *, nle_obs *);
FILE *ttyrec;
} nle_ctx_t;

nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *);
typedef struct nledl_ctx nle_ctx_t;

nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *, int shared);
nle_ctx_t *nle_step(nle_ctx_t *, nle_obs *);

void nle_reset(nle_ctx_t *, nle_obs *, FILE *, nle_seeds_init_t *);
Expand All @@ -27,4 +21,6 @@ void nle_end(nle_ctx_t *);
void nle_set_seed(nle_ctx_t *, unsigned long, unsigned long, char);
void nle_get_seed(nle_ctx_t *, unsigned long *, unsigned long *, char *);

int nle_supports_shared(void);

#endif /* NLEDL_H */
80 changes: 53 additions & 27 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def _set_env_vars(options, hackdir, wizkit=None):
if wizkit is not None:
os.environ["WIZKIT"] = wizkit

_nhinstances = 0

# TODO: Not thread-safe for many reasons.
# TODO: On Linux, we could use dlmopen to use different linker namespaces,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably update some of these comments when we merge this.

# which should allow several instances of this. On MacOS, that seems
# a tough call.
class Nethack:
_instances = 0

def __init__(
self,
Expand All @@ -93,34 +93,56 @@ def __init__(
wizard=False,
hackdir=HACKDIR,
):
global _nhinstances
self._copy = copy

_nhinstances = _nhinstances + 1

if not os.path.exists(hackdir) or not os.path.exists(
os.path.join(hackdir, "sysconf")
):
raise FileNotFoundError(
"Couldn't find NetHack installation at '%s'." % hackdir
)

# Create a HACKDIR for us.
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name

# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()
self.shared = False;

# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))
if _pynethack.supports_shared():
# "shared" mode does some hacky things to enable using a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mix of 2-space and 4-space indentation is weird.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is. I apologize. I come from a world of 2-space indentation. Formatting will be cleaned up.

# shared libnethack.so, prevents writing to any files, and does
# not chdir.
self.shared = True
dlpath = DLPATH
self._hackdir = hackdir
else:

# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)
# Create a HACKDIR for us.
if os.getenv("SLURM_JOBID"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This slurm stuff probably doesn't belong here and I'll most likely remove it from this PR

self._tempdir = None
self._vardir = "/scratch/slurm_tmpdir/%s/nle-%d-%d" % (os.getenv("SLURM_JOBID"), os.getpid(), _nhinstances)
os.mkdir(self._vardir)
else:
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name
print("_vardir is ", self._vardir)

self._hackdir = self._vardir

# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()

# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))

# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)

if options is None:
options = NETHACKOPTIONS
Expand All @@ -129,10 +151,10 @@ def __init__(
self._options.append("playmode:debug")
self._wizard = wizard

_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
self._ttyrec = ttyrec

self._pynethack = _pynethack.Nethack(dlpath, ttyrec)
self._pynethack = _pynethack.Nethack(dlpath, ttyrec, self.shared)

self._obs_buffers = {}

Expand All @@ -154,6 +176,8 @@ def step(self, action):
return self._step_return(), self._pynethack.done()

def _write_wizkit_file(self, wizkit_items):
if self._vardir is None:
raise RuntimeError("FIXME: shared wizkit: can't write to HACKDIR as it is a shared directory")
heiner marked this conversation as resolved.
Show resolved Hide resolved
# TODO ideally we need to check the validity of the requested items
with open(os.path.join(self._vardir, WIZKIT_FNAME), "w") as f:
for item in wizkit_items:
Expand All @@ -164,9 +188,9 @@ def reset(self, new_ttyrec=None, wizkit_items=None):
if not self._wizard:
raise ValueError("Set wizard=True to use the wizkit option.")
self._write_wizkit_file(wizkit_items)
_set_env_vars(self._options, self._vardir, wizkit=WIZKIT_FNAME)
_set_env_vars(self._options, self._hackdir, wizkit=WIZKIT_FNAME)
else:
_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
if new_ttyrec is None:
self._pynethack.reset()
else:
Expand All @@ -178,11 +202,13 @@ def reset(self, new_ttyrec=None, wizkit_items=None):

def close(self):
self._pynethack.close()
try:
os.chdir(self._oldcwd)
except IOError:
os.chdir(os.path.dirname(os.path.realpath(__file__)))
self._tempdir.cleanup()
if not self.shared:
try:
os.chdir(self._oldcwd)
except IOError:
os.chdir(os.path.dirname(os.path.realpath(__file__)))
if self._tempdir is not None:
self._tempdir.cleanup()

def set_initial_seeds(self, core, disp, reseed=False):
self._pynethack.set_initial_seeds(core, disp, reseed)
Expand Down
141 changes: 89 additions & 52 deletions sys/unix/nledl.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,70 +6,102 @@

#include "nledl.h"

void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init)
{
nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);

if (!nledl->dlhandle) {
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
}

dlerror(); /* Clear any existing error */

#if defined(__linux__) && defined(__x86_64__)
#define HASSHARED
heiner marked this conversation as resolved.
Show resolved Hide resolved
#endif

/*
 * Entry points of the "shared" loading backend, forward-declared here
 * instead of being pulled in from a header.
 * NOTE(review): presumably implemented in sys/unix/nleshared.cc, which the
 * build adds to the nethackdl library -- confirm, and confirm the
 * definitions are given C linkage so these declarations resolve.
 */

/* Called with the library path in shared mode; the returned pointer is the
 * handle passed to the three functions below. */
void* nleshared_open(const char *dlpath);
/* Called on the handle when the context is closed, after the game's
 * nle_end has run. */
void nleshared_close(void* handle);
/* Called on the handle during a shared-mode reset, between nle_end and the
 * next nle_start. */
void nleshared_reset(void* handle);
/* Shared-mode counterpart of dlsym: looks up symname through the handle. */
void* nleshared_sym(void* handle, const char* symname);

typedef struct nledl_ctx {
void* shared;
char dlpath[1024];
void *dlhandle;
void *nle_ctx;
void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *);
start = dlsym(nledl->dlhandle, "nle_start");
nledl->nle_ctx = start(obs, nledl->ttyrec, seed_init);
void *(*step)(void *, nle_obs *);
void (*end)(void *);
FILE *ttyrec;
} nle_ctx_t;

/*
 * Look up `name` through whichever backend this context was opened with.
 *
 * In shared mode the lookup is delegated to nleshared_sym() and its result
 * is returned as-is. Otherwise the symbol is resolved with dlsym() on the
 * dlopen handle; if dlerror() reports a problem the message is printed to
 * stderr and the process exits with EXIT_FAILURE.
 */
static void* sym(nle_ctx_t *nledl, const char* name) {
    void *addr;
    char *err;

    if (nledl->shared)
        return nleshared_sym(nledl->shared, name);

    /* Discard any stale error so the check below refers to this lookup. */
    dlerror();
    addr = dlsym(nledl->dlhandle, name);
    err = dlerror();
    if (err) {
        fprintf(stderr, "%s\n", err);
        exit(EXIT_FAILURE);
    }
    return addr;
}

nledl->step = dlsym(nledl->dlhandle, "nle_step");

error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init, int shared)
{
nledl->shared = NULL;
if (shared) {
#ifdef HASSHARED
nledl->shared = nleshared_open(nledl->dlpath);
#else
fprintf(stderr, "Shared mode not supported on this system!\n");
exit(EXIT_FAILURE);
#endif
} else {
nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);
if (!nledl->dlhandle) {
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
}
}


nledl->start = sym(nledl, "nle_start");
nledl->step = sym(nledl, "nle_step");
nledl->end = sym(nledl, "nle_end");

nledl->nle_ctx = nledl->start(obs, nledl->ttyrec, seed_init);
}

void
nledl_close(nle_ctx_t *nledl)
{
void (*end)(void *);
nledl->end(nledl->nle_ctx);

end = dlsym(nledl->dlhandle, "nle_end");
end(nledl->nle_ctx);
if (nledl->shared) {
nleshared_close(nledl->shared);
} else {
if (dlclose(nledl->dlhandle)) {
fprintf(stderr, "Error in dlclose: %s\n", dlerror());
exit(EXIT_FAILURE);
}

if (dlclose(nledl->dlhandle)) {
fprintf(stderr, "Error in dlclose: %s\n", dlerror());
exit(EXIT_FAILURE);
dlerror();
}

dlerror();
}

nle_ctx_t *
nle_start(const char *dlpath, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seed_init)
nle_seeds_init_t *seed_init, int shared)
{
/* TODO: Consider getting ttyrec path from caller? */
struct nledl_ctx *nledl = malloc(sizeof(struct nledl_ctx));
nledl->ttyrec = ttyrec;
strncpy(nledl->dlpath, dlpath, sizeof(nledl->dlpath));

nledl_init(nledl, obs, seed_init);
nledl_init(nledl, obs, seed_init, shared);
return nledl;
};

nle_ctx_t *
nle_step(nle_ctx_t *nledl, nle_obs *obs)
{
if (!nledl || !nledl->dlhandle || !nledl->nle_ctx) {
if (!nledl || (!nledl->dlhandle && !nledl->shared) || !nledl->nle_ctx) {
fprintf(stderr, "Illegal nledl_ctx\n");
exit(EXIT_FAILURE);
}
Expand All @@ -85,14 +117,22 @@ void
nle_reset(nle_ctx_t *nledl, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seed_init)
{
nledl_close(nledl);
/* Reset file only if not-NULL. */
if (ttyrec)
nledl->ttyrec = ttyrec;

// TODO: Consider refactoring nledl.h such that we expose this init
// function but drop reset.
nledl_init(nledl, obs, seed_init);
if (nledl->shared) {
    /* Shared mode: end the current game and reset the loaded library via
     * its handle instead of closing and re-opening it. */
    nledl->end(nledl->nle_ctx);
    nleshared_reset(nledl->shared);
    /* Reset file only if not-NULL. */
    if (ttyrec)
        nledl->ttyrec = ttyrec;
    /* Start with the stored stream rather than the raw argument: a reset
     * without a new ttyrec passes NULL here, and the non-shared path below
     * likewise restarts with nledl->ttyrec (through nledl_init). */
    nledl->nle_ctx = nledl->start(obs, nledl->ttyrec, seed_init);
} else {
    nledl_close(nledl);
    /* Reset file only if not-NULL. */
    if (ttyrec)
        nledl->ttyrec = ttyrec;

    // TODO: Consider refactoring nledl.h such that we expose this init
    // function but drop reset.
    nledl_init(nledl, obs, seed_init, 0);
}
}

void
Expand All @@ -108,13 +148,7 @@ nle_set_seed(nle_ctx_t *nledl, unsigned long core, unsigned long disp,
{
void (*set_seed)(void *, unsigned long, unsigned long, char);

set_seed = dlsym(nledl->dlhandle, "nle_set_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
set_seed = sym(nledl, "nle_set_seed");

set_seed(nledl->nle_ctx, core, disp, reseed);
}
Expand All @@ -125,16 +159,19 @@ nle_get_seed(nle_ctx_t *nledl, unsigned long *core, unsigned long *disp,
{
void (*get_seed)(void *, unsigned long *, unsigned long *, char *);

get_seed = dlsym(nledl->dlhandle, "nle_get_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
get_seed = sym(nledl, "nle_get_seed");

/* Careful here. NetHack has different ideas of what a boolean is
* than C++ (see global.h and SKIP_BOOLEAN). But one byte should be fine.
*/
get_seed(nledl->nle_ctx, core, disp, reseed);
}

/*
 * Report whether this build was compiled with shared-mode support.
 * Returns 1 when HASSHARED is defined at compile time, 0 otherwise.
 */
int
nle_supports_shared(void) {
    int supported = 0;
#ifdef HASSHARED
    supported = 1;
#endif
    return supported;
}
Loading