
Commit

refactored "blocks" to "layers"
MarcoMeter committed Sep 10, 2024
1 parent 9a50c0a commit c4fde03
Showing 2 changed files with 20 additions and 21 deletions.
13 changes: 6 additions & 7 deletions cleanrl/ppo_trxl/enjoy.py
@@ -17,19 +17,18 @@
 if not max_episode_steps:
     max_episode_steps = env.max_episode_steps
 if max_episode_steps <= 0:
-    max_episode_steps = 2048  # Memory Gym envs have max_episode_steps set to -1
+    max_episode_steps = 1024  # Memory Gym envs have max_episode_steps set to -1
     # May episode impacts positional encoding, so make sure to set this accordingly

 # Setup agent and load its model parameters
 action_space_shape = (
-    (env.action_space.n,)
-    if isinstance(env.action_space, gym.spaces.Discrete)
-    else tuple(env.action_space.nvec)
+    (env.action_space.n,) if isinstance(env.action_space, gym.spaces.Discrete) else tuple(env.action_space.nvec)
 )
 agent = Agent(args, env.observation_space, action_space_shape, max_episode_steps)
 agent.load_state_dict(checkpoint["model_weights"])

 # Setup memory, mask and indices
-memory = torch.zeros((1, max_episode_steps, args.trxl_num_blocks, args.trxl_dim), dtype=torch.float32)
+memory = torch.zeros((1, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
 memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
 repetitions = torch.repeat_interleave(
     torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0
@@ -58,5 +57,5 @@
     done = termination or truncation
     t += 1

-print(info)
+print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
 env.close()
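
For readers following along, the memory_mask built above is untouched by this commit; it is a lower-triangular matrix with the diagonal excluded, so step t may only attend to the hidden states of steps before t. A minimal standalone sketch, assuming a hypothetical trxl_memory_length of 4:

import torch

memory_length = 4  # hypothetical stand-in for args.trxl_memory_length
# Lower-triangular mask with the diagonal excluded: row t marks which memory
# slots step t is allowed to attend to (only steps 0..t-1).
memory_mask = torch.tril(torch.ones((memory_length, memory_length)), diagonal=-1)
print(memory_mask)
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.]])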
28 changes: 14 additions & 14 deletions cleanrl/ppo_trxl/ppo_trxl.py
@@ -52,7 +52,7 @@ class Args:
     """the number of parallel game environments"""
     num_steps: int = 512
     """the number of steps to run in each environment per policy rollout"""
-    anneal_steps: int = 32*512*10000
+    anneal_steps: int = 32 * 512 * 10000
     """the number of steps to linearly anneal the learning rate and entropy coefficient from initial to final"""
     gamma: float = 0.995
     """the discount factor gamma"""
@@ -80,8 +80,8 @@ class Args:
     """the target KL divergence threshold"""

     # Transformer-XL specific arguments
-    trxl_num_blocks: int = 3
-    """the number of transformer blocks"""
+    trxl_num_layers: int = 3
+    """the number of transformer layers"""
     trxl_num_heads: int = 4
     """the number of heads used in multi-head attention"""
     trxl_dim: int = 384
@@ -194,7 +194,7 @@ def forward(self, values, keys, query, mask):
         return self.fc_out(out), attention


-class TransformerBlock(nn.Module):
+class TransformerLayer(nn.Module):
     def __init__(self, dim, num_heads):
         super().__init__()
         self.attention = MultiHeadAttention(dim, num_heads)
@@ -217,29 +217,29 @@ def forward(self, value, key, query, mask):


 class Transformer(nn.Module):
-    def __init__(self, num_blocks, dim, num_heads, max_episode_steps, positional_encoding):
+    def __init__(self, num_layers, dim, num_heads, max_episode_steps, positional_encoding):
         super().__init__()
         self.max_episode_steps = max_episode_steps
         self.positional_encoding = positional_encoding
         if positional_encoding == "absolute":
             self.pos_embedding = PositionalEncoding(dim)
         elif positional_encoding == "learned":
             self.pos_embedding = nn.Parameter(torch.randn(max_episode_steps, dim))
-        self.transformer_blocks = nn.ModuleList([TransformerBlock(dim, num_heads) for _ in range(num_blocks)])
+        self.transformer_layers = nn.ModuleList([TransformerLayer(dim, num_heads) for _ in range(num_layers)])

     def forward(self, x, memories, mask, memory_indices):
-        # Add positional encoding to every transformer block input
+        # Add positional encoding to every transformer layer input
         if self.positional_encoding == "absolute":
             pos_embedding = self.pos_embedding(self.max_episode_steps)[memory_indices]
             memories = memories + pos_embedding.unsqueeze(2)
         elif self.positional_encoding == "learned":
             memories = memories + self.pos_embedding[memory_indices].unsqueeze(2)

-        # Forward transformer blocks and return new memories (i.e. hidden states)
+        # Forward transformer layers and return new memories (i.e. hidden states)
         out_memories = []
-        for i, block in enumerate(self.transformer_blocks):
+        for i, layer in enumerate(self.transformer_layers):
             out_memories.append(x.detach())
-            x, attention_weights = block(
+            x, attention_weights = layer(
                 memories[:, :, i], memories[:, :, i], x.unsqueeze(1), mask
             )  # args: value, key, query, mask
             x = x.squeeze()
@@ -270,7 +270,7 @@ def __init__(self, args, observation_space, action_space_shape, max_episode_step
         self.encoder = layer_init(nn.Linear(observation_space.shape[0], args.trxl_dim))

         self.transformer = Transformer(
-            args.trxl_num_blocks, args.trxl_dim, args.trxl_num_heads, self.max_episode_steps, args.trxl_positional_encoding
+            args.trxl_num_layers, args.trxl_dim, args.trxl_num_heads, self.max_episode_steps, args.trxl_positional_encoding
         )

         self.hidden_post_trxl = nn.Sequential(
@@ -402,7 +402,7 @@ def reconstruct_observation(self):
     log_probs = torch.zeros((args.num_steps, args.num_envs, len(action_space_shape)))
     values = torch.zeros((args.num_steps, args.num_envs))
     # The length of stored-memories is equal to the number of sampled episodes during training data sampling
-    # (num_episodes, max_episode_length, num_blocks, embed_dim)
+    # (num_episodes, max_episode_length, num_layers, embed_dim)
     stored_memories = []
     # Memory mask used during attention
     stored_memory_masks = torch.zeros((args.num_steps, args.num_envs, args.trxl_memory_length), dtype=torch.bool)
@@ -419,7 +419,7 @@ def reconstruct_observation(self):
     next_obs = torch.Tensor(next_obs).to(device)
     next_done = torch.zeros(args.num_envs)
     # Setup placeholders for each environments's current episodic memory
-    next_memory = torch.zeros((args.num_envs, max_episode_steps, args.trxl_num_blocks, args.trxl_dim), dtype=torch.float32)
+    next_memory = torch.zeros((args.num_envs, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
     # Generate episodic memory mask used in attention
     memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
     """ e.g. memory mask tensor looks like this if memory_length = 6
@@ -498,7 +498,7 @@ def reconstruct_observation(self):
                 stored_memories[mem_index] = stored_memories[mem_index].clone()
                 # Reset episodic memory
                 next_memory[id] = torch.zeros(
-                    (max_episode_steps, args.trxl_num_blocks, args.trxl_dim), dtype=torch.float32
+                    (max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32
                 )
                 if step < args.num_steps - 1:
                     # Store memory inside the buffer
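
The rename only touches identifiers; the episodic memory layout itself is unchanged. A minimal sketch of that layout with hypothetical sizes (in the script these come from Args and the environment):

import torch

# Hypothetical sizes standing in for args.num_envs, max_episode_steps,
# args.trxl_num_layers (formerly args.trxl_num_blocks) and args.trxl_dim.
num_envs, max_episode_steps = 2, 8
trxl_num_layers, trxl_dim = 3, 384

# One memory slot per environment, per step of the longest possible episode,
# per transformer layer; memories[:, :, i] is the cache read by layer i.
next_memory = torch.zeros((num_envs, max_episode_steps, trxl_num_layers, trxl_dim), dtype=torch.float32)
print(next_memory.shape)  # torch.Size([2, 8, 3, 384])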
