diff --git a/PytorchWildlife/models/classification/resnet/amazon.py b/PytorchWildlife/models/classification/resnet/amazon.py index c427fd518..c82a0f03c 100644 --- a/PytorchWildlife/models/classification/resnet/amazon.py +++ b/PytorchWildlife/models/classification/resnet/amazon.py @@ -93,6 +93,8 @@ def results_generation(self, logits, img_ids, id_strip=None): probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] + confidences = probs[0].tolist() + result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] results = [] for pred, img_id, conf in zip(preds, img_ids, confs): @@ -100,6 +102,7 @@ def results_generation(self, logits, img_ids, id_strip=None): r["prediction"] = self.CLASS_NAMES[pred.item()] r["class_id"] = pred.item() r["confidence"] = conf.item() + r["all_confidences"] = result results.append(r) return results diff --git a/PytorchWildlife/models/classification/resnet/serengeti.py b/PytorchWildlife/models/classification/resnet/serengeti.py index 3c32652e7..fa7671d51 100644 --- a/PytorchWildlife/models/classification/resnet/serengeti.py +++ b/PytorchWildlife/models/classification/resnet/serengeti.py @@ -67,6 +67,8 @@ def results_generation(self, logits, img_ids, id_strip=None): probs = torch.softmax(logits, dim=1) preds = probs.argmax(dim=1) confs = probs.max(dim=1)[0] + confidences = probs[0].tolist() + result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] results = [] for pred, img_id, conf in zip(preds, img_ids, confs): @@ -74,6 +76,7 @@ def results_generation(self, logits, img_ids, id_strip=None): r["prediction"] = self.CLASS_NAMES[pred.item()] r["class_id"] = pred.item() r["confidence"] = conf.item() + r["all_confidences"] = result results.append(r) return results