Skip to content

Commit

Permalink
update output casting
Browse files Browse the repository at this point in the history
  • Loading branch information
SumanthRH committed Jan 22, 2024
1 parent dd8a7d7 commit cbe80fe
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/ecco/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def __str__(self):
return "<LMOutput '{}' # of lm outputs: {}>".format(self.output_text, len(self._get_hidden_states()[1][-1]))

def to(self, tensor: torch.Tensor):
if self.device == 'cuda':
return tensor.to('cuda')
return tensor
return tensor.to(self.device)

def explorable(self, printJson: Optional[bool] = False):

Expand Down Expand Up @@ -394,7 +392,7 @@ def layer_predictions(self, position: int = 1, topk: Optional[int] = 10, layer:

layer_top_tokens = [self.tokenizer.decode(t) for t in sorted_softmax[-k:]][::-1]
top_tokens.append(layer_top_tokens)
layer_probs = softmax[sorted_softmax[-k:]].cpu().detach().numpy()[::-1]
layer_probs = softmax[sorted_softmax[-k:]].float().cpu().detach().numpy()[::-1]
probs.append(layer_probs.tolist())

# Package in output format
Expand Down

0 comments on commit cbe80fe

Please sign in to comment.