diff --git a/core/algorithms/lacie/base_lacie.py b/core/algorithms/lacie/base_lacie.py
index 145f1f7..6f4bf93 100644
--- a/core/algorithms/lacie/base_lacie.py
+++ b/core/algorithms/lacie/base_lacie.py
@@ -85,15 +85,13 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=0)
         self.log_softmax = nn.LogSoftmax(dim=0)
 
-    def compute_contrastive_loss(self, rollouts, advantages):
-        """
-        Contrastive Predictive Coding for learning representation and density ratio
-        :param rollouts: Storage's instance
-        :param advantage: tensor of shape: (timestep, n_processes, 1)
-        """
-        # FIXME: only compatible with 1D observation
-        num_steps, n_processes, _ = advantages.shape
-
+    def _encode_input_sequences(self, rollouts):
+        """
+        Encode the input sequences (last 2 scalars of rollouts.obs[1:]) from right to left
+        :param rollouts: Storage's instance
+        :return: encoded input sequences, flipped back to the original time order
+        """
+        num_steps, n_processes, _ = rollouts.actions.shape
         # INPUT SEQUENCES AND MASKS
         # the stochastic input will be defined by last 2 scalar
         input_seq = rollouts.obs[1:, :, -2:]
@@ -104,7 +102,7 @@ def compute_contrastive_loss(self, rollouts, advantages):
         # encode the input sequence
         # Let's figure out which steps in the sequence have a zero for any agent
         # We will always assume t=0 has a zero in it as that makes the logic cleaner
-        has_zeros = ((masks[1:] == 0.0)
+        has_zeros = ((masks[:-1] == 0.0)
                      .any(dim=-1)
                      .nonzero()
                      .squeeze()
@@ -128,8 +126,17 @@ def compute_contrastive_loss(self, rollouts, advantages):
             start_idx = has_zeros[i]
             end_idx = has_zeros[i + 1]
 
-            output, _ = self.input_seq_encoder(
-                input_seq[start_idx + 1: end_idx + 1])
+            try:
+                output, _ = self.input_seq_encoder(
+                    input_seq[start_idx + 1: end_idx + 1])
+            except Exception:
+                # Dump the segment bounds for debugging, then re-raise:
+                # swallowing the error would leave `output` unbound (or
+                # stale from the previous segment) for the append below.
+                print(start_idx)
+                print(end_idx)
+                print(has_zeros)
+                raise
 
             outputs.append(output)
 
@@ -139,12 +146,31 @@ def compute_contrastive_loss(self, rollouts, advantages):
         # reverse back
         input_seq = torch.flip(input_seq, [0])
 
+        return input_seq
+
+    def _encode_advantages(self, advantages):
+        """
+        Encode the advantages with the advantage encoder
+        :param advantages: tensor of shape: (timestep, n_processes, 1)
+        :return: encoded advantages reshaped to (timestep, n_processes, -1)
+        """
+        # FIXME: only compatible with 1D observation
+        num_steps, n_processes, _ = advantages.shape
         # ADVANTAGES
         # encode
         # n_steps x n_process x hidden_dim/2
         advantages = self.advantage_encoder(
             advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)
 
+        return advantages
+
+    def _encode_states(self, rollouts):
+        """
+        Encode the states (rollouts.obs[:-1]) with the state encoder
+        :param rollouts: Storage's instance
+        :return: encoded states reshaped to (timestep, n_processes, -1)
+        """
+        num_steps, n_processes, _ = rollouts.actions.shape
         # STATES
         # encode
         # n_steps x n_process x hidden_dim/2
@@ -154,6 +180,15 @@ def compute_contrastive_loss(self, rollouts, advantages):
         states = self.state_encoder(
             states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)
 
+        return states
+
+    def _encode_actions(self, rollouts):
+        """
+        Encode the actions (rollouts.actions) with the action encoder
+        :param rollouts: Storage's instance
+        :return: encoded actions reshaped to (timestep, n_processes, -1)
+        """
+        num_steps, n_processes, _ = rollouts.actions.shape
         # ACTION
         # encode
         # n_steps x n_process x 1
@@ -161,8 +196,26 @@ def compute_contrastive_loss(self, rollouts, advantages):
         actions = self.action_encoder(
             actions.reshape(-1)).reshape(num_steps, n_processes, -1)
 
-        # condition = STATE + ADVANTAGE
-        conditions = torch.cat([advantages, states, actions], dim=-1)
+        return actions
+
+    def compute_contrastive_loss(self, rollouts, encoded_advantages):
+        """
+        Contrastive Predictive Coding for learning representation and density ratio
+        :param rollouts: Storage's instance
+        :param encoded_advantages: raw advantages (encoded inside this method), tensor of shape: (timestep, n_processes, 1)
+        """
+        # FIXME: only compatible with 1D observation
+        num_steps, n_processes, _ = encoded_advantages.shape
+
+        # encode all the inputs
+        encoded_input_seq = self._encode_input_sequences(rollouts)
+        encoded_advantages = self._encode_advantages(encoded_advantages)
+        encoded_states = self._encode_states(rollouts)
+        encoded_actions = self._encode_actions(rollouts)
+
+        # condition = STATE + ADVANTAGE + ACTIONS
+        conditions = torch.cat(
+            [encoded_advantages, encoded_states, encoded_actions], dim=-1)
 
         # reshape to n_steps x hidden_dim x n_processes
         conditions = conditions.permute(0, 2, 1)
@@ -170,13 +223,14 @@ def compute_contrastive_loss(self, rollouts, advantages):
         contrastive_loss = 0
         correct = 0
         for i in range(num_steps):
-            density_ratio = torch.mm(input_seq[i], conditions[i])
+            # f(Z, s0, a0, R) WITHOUT exponential
+            f_value = torch.mm(encoded_input_seq[i], conditions[i])
             # accuracy
             correct += torch.sum(torch.eq(torch.argmax(self.softmax(
-                density_ratio), dim=1), torch.arange(0, n_processes).to(self.device)))
+                f_value), dim=1), torch.arange(0, n_processes).to(self.device)))
             # nce
             contrastive_loss += torch.sum(
-                torch.diag(self.log_softmax(density_ratio)))
+                torch.diag(self.log_softmax(f_value)))
 
         # log loss
         contrastive_loss /= -1*n_processes*num_steps
@@ -193,76 +247,14 @@ def compute_weighted_advantages(self, rollouts, advantages):
         # FIXME: only compatible with 1D observation
         num_steps, n_processes, _ = advantages.shape
 
-        # INPUT SEQUENCES AND MASKS
-        # the stochastic input will be defined by last 2 scalar
-        input_seq = rollouts.obs[1:, :, -2:]
-        masks = rollouts.masks[1:].reshape(num_steps, n_processes)
-        # reverse the input seq order since we want to compute from right to left
-        input_seq = torch.flip(input_seq, [0])
-        masks = torch.flip(masks, [0])
-        # encode the input sequence
-        # Let's figure out which steps in the sequence have a zero for any agent
-        # We will always assume t=0 has a zero in it as that makes the logic cleaner
-        has_zeros = ((masks[1:] == 0.0)
-                     .any(dim=-1)
-                     .nonzero()
-                     .squeeze()
-                     .cpu())
-
-        # +1 to correct the masks[1:]
-        if has_zeros.dim() == 0:
-            # Deal with scalar
-            has_zeros = [has_zeros.item() + 1]
-        else:
-            has_zeros = (has_zeros + 1).numpy().tolist()
-
-        # add t=0 and t=T to the list
-        has_zeros = [-1] + has_zeros + [num_steps - 1]
-
-        outputs = []
-
-        for i in range(len(has_zeros) - 1):
-            # We can now process steps that don't have any zeros in masks together!
-            # This is much faster
-            start_idx = has_zeros[i]
-            end_idx = has_zeros[i + 1]
-
-            output, _ = self.input_seq_encoder(
-                input_seq[start_idx + 1: end_idx + 1])
-
-            outputs.append(output)
-
-        # x is a (T, N, -1) tensor
-        input_seq = torch.cat(outputs, dim=0)
-        assert len(input_seq) == num_steps
-        # reverse back
-        input_seq = torch.flip(input_seq, [0])
-
-        # ADVANTAGES
-        # encode
-        # n_steps x n_process x hidden_dim/2
-        encoded_advantages = self.advantage_encoder(
-            advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)
-
-        # STATES
-        # encode
-        # n_steps x n_process x hidden_dim/2
-        states = rollouts.obs[:-1]
-        # FIXME: hard code for 1D env
-        states_shape = states.shape[2:][0]
-        states = self.state_encoder(
-            states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)
-
-        # ACTION
-        # encode
-        # n_steps x n_process x 1
-        actions = rollouts.actions
-        actions = self.action_encoder(
-            actions.reshape(-1)).reshape(num_steps, n_processes, -1)
+        input_seq = self._encode_input_sequences(rollouts)
+        encoded_advantages = self._encode_advantages(advantages)
+        encoded_states = self._encode_states(rollouts)
+        encoded_actions = self._encode_actions(rollouts)
 
         # condition = STATE + ADVANTAGE
         conditions = torch.cat(
-            [encoded_advantages, states, actions], dim=-1)
+            [encoded_advantages, encoded_states, encoded_actions], dim=-1)
 
         # reshape to n_steps x hidden_dim x n_processes
         conditions = conditions.permute(0, 2, 1)