Skip to content

Commit

Permalink
fixed has_zeros input in lacie + refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
lehduong committed Jul 3, 2020
1 parent 14daf9f commit 293ab82
Showing 1 changed file with 52 additions and 84 deletions.
136 changes: 52 additions & 84 deletions core/algorithms/lacie/base_lacie.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,8 @@ def __init__(self,
self.softmax = nn.Softmax(dim=0)
self.log_softmax = nn.LogSoftmax(dim=0)

def compute_contrastive_loss(self, rollouts, advantages):
"""
Contrastive Predictive Coding for learning representation and density ratio
:param rollouts: Storage's instance
:param advantage: tensor of shape: (timestep, n_processes, 1)
"""
# FIXME: only compatible with 1D observation
num_steps, n_processes, _ = advantages.shape

def _encode_input_sequences(self, rollouts):
num_steps, n_processes, _ = rollouts.actions.shape
# INPUT SEQUENCES AND MASKS
# the stochastic input will be defined by last 2 scalar
input_seq = rollouts.obs[1:, :, -2:]
Expand All @@ -104,7 +97,7 @@ def compute_contrastive_loss(self, rollouts, advantages):
# encode the input sequence
# Let's figure out which steps in the sequence have a zero for any agent
# We will always assume t=0 has a zero in it as that makes the logic cleaner
has_zeros = ((masks[1:] == 0.0)
has_zeros = ((masks[:-1] == 0.0)
.any(dim=-1)
.nonzero()
.squeeze()
Expand All @@ -128,8 +121,13 @@ def compute_contrastive_loss(self, rollouts, advantages):
start_idx = has_zeros[i]
end_idx = has_zeros[i + 1]

output, _ = self.input_seq_encoder(
input_seq[start_idx + 1: end_idx + 1])
try:
output, _ = self.input_seq_encoder(
input_seq[start_idx + 1: end_idx + 1])
except:
print(start_idx)
print(end_idx)
print(has_zeros)

outputs.append(output)

Expand All @@ -139,12 +137,21 @@ def compute_contrastive_loss(self, rollouts, advantages):
# reverse back
input_seq = torch.flip(input_seq, [0])

return input_seq

def _encode_advantages(self, advantages):
# FIXME: only compatible with 1D observation
num_steps, n_processes, _ = advantages.shape
# ADVANTAGES
# encode
# n_steps x n_process x hidden_dim/2
advantages = self.advantage_encoder(
advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)

return advantages

def _encode_states(self, rollouts):
num_steps, n_processes, _ = rollouts.actions.shape
# STATES
# encode
# n_steps x n_process x hidden_dim/2
Expand All @@ -154,29 +161,52 @@ def compute_contrastive_loss(self, rollouts, advantages):
states = self.state_encoder(
states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)

return states

def _encode_actions(self, rollouts):
num_steps, n_processes, _ = rollouts.actions.shape
# ACTION
# encode
# n_steps x n_process x 1
actions = rollouts.actions
actions = self.action_encoder(
actions.reshape(-1)).reshape(num_steps, n_processes, -1)

# condition = STATE + ADVANTAGE
conditions = torch.cat([advantages, states, actions], dim=-1)
return actions

def compute_contrastive_loss(self, rollouts, encoded_advantages):
"""
Contrastive Predictive Coding for learning representation and density ratio
:param rollouts: Storage's instance
:param advantage: tensor of shape: (timestep, n_processes, 1)
"""
# FIXME: only compatible with 1D observation
num_steps, n_processes, _ = encoded_advantages.shape

# encoded all the input
encoded_input_seq = self._encode_input_sequences(rollouts)
encoded_advantages = self._encode_advantages(encoded_advantages)
encoded_states = self._encode_states(rollouts)
encoded_actions = self._encode_actions(rollouts)

# condition = STATE + ADVANTAGE + ACTIONS
conditions = torch.cat(
[encoded_advantages, encoded_states, encoded_actions], dim=-1)
# reshape to n_steps x hidden_dim x n_processes
conditions = conditions.permute(0, 2, 1)

# compute nce
contrastive_loss = 0
correct = 0
for i in range(num_steps):
density_ratio = torch.mm(input_seq[i], conditions[i])
# f(Z, s0, a0, R) WITHOUT exponential
f_value = torch.mm(encoded_input_seq[i], conditions[i])
# accuracy
correct += torch.sum(torch.eq(torch.argmax(self.softmax(
density_ratio), dim=1), torch.arange(0, n_processes).to(self.device)))
f_value), dim=1), torch.arange(0, n_processes).to(self.device)))
# nce
contrastive_loss += torch.sum(
torch.diag(self.log_softmax(density_ratio)))
torch.diag(self.log_softmax(f_value)))

# log loss
contrastive_loss /= -1*n_processes*num_steps
Expand All @@ -193,76 +223,14 @@ def compute_weighted_advantages(self, rollouts, advantages):
# FIXME: only compatible with 1D observation
num_steps, n_processes, _ = advantages.shape

# INPUT SEQUENCES AND MASKS
# the stochastic input will be defined by last 2 scalar
input_seq = rollouts.obs[1:, :, -2:]
masks = rollouts.masks[1:].reshape(num_steps, n_processes)
# reverse the input seq order since we want to compute from right to left
input_seq = torch.flip(input_seq, [0])
masks = torch.flip(masks, [0])
# encode the input sequence
# Let's figure out which steps in the sequence have a zero for any agent
# We will always assume t=0 has a zero in it as that makes the logic cleaner
has_zeros = ((masks[1:] == 0.0)
.any(dim=-1)
.nonzero()
.squeeze()
.cpu())

# +1 to correct the masks[1:]
if has_zeros.dim() == 0:
# Deal with scalar
has_zeros = [has_zeros.item() + 1]
else:
has_zeros = (has_zeros + 1).numpy().tolist()

# add t=0 and t=T to the list
has_zeros = [-1] + has_zeros + [num_steps - 1]

outputs = []

for i in range(len(has_zeros) - 1):
# We can now process steps that don't have any zeros in masks together!
# This is much faster
start_idx = has_zeros[i]
end_idx = has_zeros[i + 1]

output, _ = self.input_seq_encoder(
input_seq[start_idx + 1: end_idx + 1])

outputs.append(output)

# x is a (T, N, -1) tensor
input_seq = torch.cat(outputs, dim=0)
assert len(input_seq) == num_steps
# reverse back
input_seq = torch.flip(input_seq, [0])

# ADVANTAGES
# encode
# n_steps x n_process x hidden_dim/2
encoded_advantages = self.advantage_encoder(
advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)

# STATES
# encode
# n_steps x n_process x hidden_dim/2
states = rollouts.obs[:-1]
# FIXME: hard code for 1D env
states_shape = states.shape[2:][0]
states = self.state_encoder(
states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)

# ACTION
# encode
# n_steps x n_process x 1
actions = rollouts.actions
actions = self.action_encoder(
actions.reshape(-1)).reshape(num_steps, n_processes, -1)
input_seq = self._encode_input_sequences(rollouts)
encoded_advantages = self._encode_advantages(advantages)
encoded_states = self._encode_states(rollouts)
encoded_actions = self._encode_actions(rollouts)

# condition = STATE + ADVANTAGE
conditions = torch.cat(
[encoded_advantages, states, actions], dim=-1)
[encoded_advantages, encoded_states, encoded_actions], dim=-1)
# reshape to n_steps x hidden_dim x n_processes
conditions = conditions.permute(0, 2, 1)

Expand Down

0 comments on commit 293ab82

Please sign in to comment.