Skip to content

Commit

Permalink
adding option to return ram state in info
Browse files Browse the repository at this point in the history
  • Loading branch information
renos committed Mar 30, 2024
1 parent f411fc2 commit 0bb3479
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions envpool/atari/atari_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class AtariEnvFns {
"img_height"_.Bind(84), "img_width"_.Bind(84),
"task"_.Bind(std::string("pong")), "full_action_space"_.Bind(false),
"repeat_action_probability"_.Bind(0.0f),
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true));
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true), "expose_ram"_.Bind(false));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
Expand All @@ -66,7 +66,9 @@ class AtariEnvFns {
{0, 255})),
"info:lives"_.Bind(Spec<int>({-1})),
"info:reward"_.Bind(Spec<float>({-1})),
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})),
"info:ram"_.Bind(Spec<uint8_t>({128}, {0, 255}))
);
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down Expand Up @@ -99,6 +101,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
std::vector<Array> maxpool_buf_;
Array resize_img_;
std::uniform_int_distribution<> dist_noop_;
bool expose_ram_{false};
std::string rom_path_;

public:
Expand All @@ -121,6 +124,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
spec.config["img_width"_]}),
resize_img_(resize_spec_),
dist_noop_(0, spec.config["noop_max"_] - 1),
expose_ram_(spec.config["expose_ram"_]),
rom_path_(GetRomPath(spec.config["base_path"_], spec.config["task"_])) {
env_->setFloat("repeat_action_probability",
spec.config["repeat_action_probability"_]);
Expand Down Expand Up @@ -247,6 +251,23 @@ class AtariEnv : public Env<AtariEnvSpec> {
.Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3)
.Assign(stack_buf_[i]);
}
// Optionally add RAM state if expose_ram_ is true
if (expose_ram_) {
// const auto& ram = env_->getRAM(); // Get a reference to the RAM.
// const size_t ram_size = ram.size(); // Obtain the size of the RAM.
// const uint8_t* ram_data_ptr = ram.data();
// std::vector<uint8_t> ram_data(ram_data_ptr, ram_data_ptr + ram_size);
const size_t ram_size = env_->getRAM().size();
std::vector<uint8_t> ram_data(ram_size);

// Assuming getRAM().array() gives direct access to the RAM data
const uint8_t* ale_ram = env_->getRAM().array();
std::copy(ale_ram, ale_ram + ram_size, ram_data.begin());
state["info:ram"_].Assign(ale_ram, ram_size);
// for (size_t i = 0; i < ram_size; ++i) {
// state["ram"_].At(i) = ram[i]; // Directly write RAM data into state
// }
}
}

/**
Expand Down

0 comments on commit 0bb3479

Please sign in to comment.