
Commit

fixed reset hxs state when episode is finished
lehduong committed Jul 3, 2020 · 1 parent d5ea6e2 · commit 14daf9f
Showing 2 changed files with 74 additions and 224 deletions.
80 changes: 74 additions & 6 deletions core/algorithms/lacie/base_lacie.py
@@ -94,14 +94,48 @@ def compute_contrastive_loss(self, rollouts, advantages):
         # FIXME: only compatible with 1D observation
         num_steps, n_processes, _ = advantages.shape
 
-        # INPUT SEQUENCES
+        # INPUT SEQUENCES AND MASKS
         # the stochastic input will be defined by the last 2 scalars
         input_seq = rollouts.obs[1:, :, -2:]
+        masks = rollouts.masks[1:].reshape(num_steps, n_processes)
         # reverse the input seq order since we want to compute from right to left
         input_seq = torch.flip(input_seq, [0])
+        masks = torch.flip(masks, [0])
         # encode the input sequence
-        # n_steps x n_processes x hidden_dim
-        input_seq, _ = self.input_seq_encoder(input_seq)
+        # Let's figure out which steps in the sequence have a zero for any agent
+        # We will always assume t=0 has a zero in it as that makes the logic cleaner
+        has_zeros = ((masks[1:] == 0.0)
+                     .any(dim=-1)
+                     .nonzero()
+                     .squeeze()
+                     .cpu())
+
+        # +1 to correct for indexing into masks[1:]
+        if has_zeros.dim() == 0:
+            # Deal with scalar
+            has_zeros = [has_zeros.item() + 1]
+        else:
+            has_zeros = (has_zeros + 1).numpy().tolist()
+
+        # add t=0 and t=T to the list
+        has_zeros = [-1] + has_zeros + [num_steps - 1]
+
+        outputs = []
+
+        for i in range(len(has_zeros) - 1):
+            # We can now process steps that don't have any zeros in masks together!
+            # This is much faster
+            start_idx = has_zeros[i]
+            end_idx = has_zeros[i + 1]
+
+            output, _ = self.input_seq_encoder(
+                input_seq[start_idx + 1: end_idx + 1])
+
+            outputs.append(output)
+
+        # input_seq is now a (T, N, -1) tensor again
+        input_seq = torch.cat(outputs, dim=0)
+        assert len(input_seq) == num_steps
         # reverse back
         input_seq = torch.flip(input_seq, [0])

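The boundary bookkeeping introduced in this hunk can be exercised on its own. Below is a minimal sketch with invented toy values (6 steps, 2 parallel environments; nothing here comes from the repository) that traces the same index arithmetic and shows which zero-free segments the loop would encode:

import torch

# 6 steps, 2 parallel envs; masks are assumed to be already flipped,
# as in the diff above. A zero mask marks an episode boundary.
num_steps, n_processes = 6, 2
masks = torch.ones(num_steps, n_processes)
masks[3, 0] = 0.0  # env 0 hits an episode boundary at step 3

# identical recipe to the hunk: find steps where any env has a zero mask
has_zeros = ((masks[1:] == 0.0)
             .any(dim=-1)
             .nonzero()
             .squeeze()
             .cpu())

# +1 compensates for indexing into masks[1:]
if has_zeros.dim() == 0:
    has_zeros = [has_zeros.item() + 1]  # 0-dim tensor: a single boundary
else:
    has_zeros = (has_zeros + 1).numpy().tolist()

# bracket with the sentinels used in this commit
has_zeros = [-1] + has_zeros + [num_steps - 1]
print(has_zeros)  # [-1, 3, 5]

# consecutive pairs delimit zero-free segments of the flipped sequence
for i in range(len(has_zeros) - 1):
    start_idx, end_idx = has_zeros[i], has_zeros[i + 1]
    print(start_idx + 1, end_idx + 1)  # prints "0 4" then "4 6"

The two printed ranges are exactly the slices input_seq[0:4] and input_seq[4:6] that the loop feeds to the encoder separately, so steps before and after the boundary never share recurrent state.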
@@ -159,14 +193,48 @@ def compute_weighted_advantages(self, rollouts, advantages):
         # FIXME: only compatible with 1D observation
         num_steps, n_processes, _ = advantages.shape
 
-        # INPUT SEQUENCES
+        # INPUT SEQUENCES AND MASKS
         # the stochastic input will be defined by the last 2 scalars
         input_seq = rollouts.obs[1:, :, -2:]
+        masks = rollouts.masks[1:].reshape(num_steps, n_processes)
         # reverse the input seq order since we want to compute from right to left
         input_seq = torch.flip(input_seq, [0])
+        masks = torch.flip(masks, [0])
         # encode the input sequence
-        # output shape: n_steps x n_processes x hidden_dim
-        input_seq, _ = self.input_seq_encoder(input_seq)
+        # Let's figure out which steps in the sequence have a zero for any agent
+        # We will always assume t=0 has a zero in it as that makes the logic cleaner
+        has_zeros = ((masks[1:] == 0.0)
+                     .any(dim=-1)
+                     .nonzero()
+                     .squeeze()
+                     .cpu())
+
+        # +1 to correct for indexing into masks[1:]
+        if has_zeros.dim() == 0:
+            # Deal with scalar
+            has_zeros = [has_zeros.item() + 1]
+        else:
+            has_zeros = (has_zeros + 1).numpy().tolist()
+
+        # add t=0 and t=T to the list
+        has_zeros = [-1] + has_zeros + [num_steps - 1]
+
+        outputs = []
+
+        for i in range(len(has_zeros) - 1):
+            # We can now process steps that don't have any zeros in masks together!
+            # This is much faster
+            start_idx = has_zeros[i]
+            end_idx = has_zeros[i + 1]
+
+            output, _ = self.input_seq_encoder(
+                input_seq[start_idx + 1: end_idx + 1])
+
+            outputs.append(output)
+
+        # input_seq is now a (T, N, -1) tensor again
+        input_seq = torch.cat(outputs, dim=0)
+        assert len(input_seq) == num_steps
         # reverse back
         input_seq = torch.flip(input_seq, [0])

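For the segmented encoding loop itself, here is a minimal self-contained sketch. It assumes input_seq_encoder behaves like an nn.GRU(input_size=2, hidden_size=64); the module type and sizes are guesses suggested by the -2: slicing and the (output, hidden) return pattern, not confirmed by this diff:

import torch
import torch.nn as nn

num_steps, n_processes = 6, 2
encoder = nn.GRU(input_size=2, hidden_size=64)  # hypothetical stand-in for input_seq_encoder
input_seq = torch.randn(num_steps, n_processes, 2)

# suppose the only episode boundary falls between steps 3 and 4,
# i.e. has_zeros as computed above came out as [-1, 3, 5]
has_zeros = [-1, 3, num_steps - 1]

outputs = []
for i in range(len(has_zeros) - 1):
    start_idx, end_idx = has_zeros[i], has_zeros[i + 1]
    # calling the GRU without an initial hidden state starts each segment
    # from zeros, so recurrent state cannot leak across episode boundaries
    output, _ = encoder(input_seq[start_idx + 1: end_idx + 1])
    outputs.append(output)

encoded = torch.cat(outputs, dim=0)
assert encoded.shape == (num_steps, n_processes, 64)

Starting every segment from a zero hidden state is the point of the change: before this commit the whole flipped sequence was encoded in a single self.input_seq_encoder(input_seq) call, carrying hidden state across episode boundaries, which appears to be the bug the commit title describes.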
218 changes: 0 additions & 218 deletions train_imitation_learning.py

This file was deleted.
