From cbe80fe75c8ddda39698cb0f7da07d0c389d82ff Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 22 Jan 2024 13:29:31 -0800 Subject: [PATCH] update output casting --- src/ecco/output.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ecco/output.py b/src/ecco/output.py index b42b9ad..a7a8c89 100644 --- a/src/ecco/output.py +++ b/src/ecco/output.py @@ -112,9 +112,7 @@ def __str__(self): return "".format(self.output_text, len(self._get_hidden_states()[1][-1])) def to(self, tensor: torch.Tensor): - if self.device == 'cuda': - return tensor.to('cuda') - return tensor + return tensor.to(self.device) def explorable(self, printJson: Optional[bool] = False): @@ -394,7 +392,7 @@ def layer_predictions(self, position: int = 1, topk: Optional[int] = 10, layer: layer_top_tokens = [self.tokenizer.decode(t) for t in sorted_softmax[-k:]][::-1] top_tokens.append(layer_top_tokens) - layer_probs = softmax[sorted_softmax[-k:]].cpu().detach().numpy()[::-1] + layer_probs = softmax[sorted_softmax[-k:]].float().cpu().detach().numpy()[::-1] probs.append(layer_probs.tolist()) # Package in output format