From 0bb347910a0798a54592b89b3c3816a21f6ebefd Mon Sep 17 00:00:00 2001 From: Renos Zabounidis Date: Sat, 30 Mar 2024 00:01:47 -0700 Subject: [PATCH] adding option to return ram state in info --- envpool/atari/atari_env.h | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/envpool/atari/atari_env.h b/envpool/atari/atari_env.h index 5c4bb992..b7710440 100644 --- a/envpool/atari/atari_env.h +++ b/envpool/atari/atari_env.h @@ -56,7 +56,7 @@ class AtariEnvFns { "img_height"_.Bind(84), "img_width"_.Bind(84), "task"_.Bind(std::string("pong")), "full_action_space"_.Bind(false), "repeat_action_probability"_.Bind(0.0f), - "use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true)); + "use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true), "expose_ram"_.Bind(false)); } template static decltype(auto) StateSpec(const Config& conf) { @@ -66,7 +66,9 @@ class AtariEnvFns { {0, 255})), "info:lives"_.Bind(Spec({-1})), "info:reward"_.Bind(Spec({-1})), - "info:terminated"_.Bind(Spec({-1}, {0, 1}))); + "info:terminated"_.Bind(Spec({-1}, {0, 1})), + "info:ram"_.Bind(Spec({128}, {0, 255})) + ); } template static decltype(auto) ActionSpec(const Config& conf) { @@ -99,6 +101,7 @@ class AtariEnv : public Env { std::vector maxpool_buf_; Array resize_img_; std::uniform_int_distribution<> dist_noop_; + bool expose_ram_{false}; std::string rom_path_; public: @@ -121,6 +124,7 @@ class AtariEnv : public Env { spec.config["img_width"_]}), resize_img_(resize_spec_), dist_noop_(0, spec.config["noop_max"_] - 1), + expose_ram_(spec.config["expose_ram"_]), rom_path_(GetRomPath(spec.config["base_path"_], spec.config["task"_])) { env_->setFloat("repeat_action_probability", spec.config["repeat_action_probability"_]); @@ -247,6 +251,23 @@ class AtariEnv : public Env { .Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3) .Assign(stack_buf_[i]); } + // Optionally add RAM state if expose_ram_ is true + if (expose_ram_) { + // const auto& ram = env_->getRAM(); // Get a reference to the RAM. + // const size_t ram_size = ram.size(); // Obtain the size of the RAM. + // const uint8_t* ram_data_ptr = ram.data(); + // std::vector ram_data(ram_data_ptr, ram_data_ptr + ram_size); + const size_t ram_size = env_->getRAM().size(); + std::vector ram_data(ram_size); + + // Assuming getRAM().array() gives direct access to the RAM data + const uint8_t* ale_ram = env_->getRAM().array(); + std::copy(ale_ram, ale_ram + ram_size, ram_data.begin()); + state["info:ram"_].Assign(ale_ram, ram_size); + // for (size_t i = 0; i < ram_size; ++i) { + // state["ram"_].At(i) = ram[i]; // Directly write RAM data into state + // } + } } /**