Skip to content

Commit

Permalink
Merge pull request #488 from microsoft/PreRelease
Browse files Browse the repository at this point in the history
Add all other class confidences to the result
  • Loading branch information
zhmiao authored May 3, 2024
2 parents 5f0f09f + 768c68d commit ef196f5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions PytorchWildlife/models/classification/resnet/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,16 @@ def results_generation(self, logits, img_ids, id_strip=None):
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
confidences = probs[0].tolist()
result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)]

results = []
for pred, img_id, conf in zip(preds, img_ids, confs):
r = {"img_id": str(img_id).strip(id_strip)}
r["prediction"] = self.CLASS_NAMES[pred.item()]
r["class_id"] = pred.item()
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)

return results
3 changes: 3 additions & 0 deletions PytorchWildlife/models/classification/resnet/serengeti.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@ def results_generation(self, logits, img_ids, id_strip=None):
probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
confidences = probs[0].tolist()
result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)]

results = []
for pred, img_id, conf in zip(preds, img_ids, confs):
r = {"img_id": str(img_id).strip(id_strip)}
r["prediction"] = self.CLASS_NAMES[pred.item()]
r["class_id"] = pred.item()
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)

return results

0 comments on commit ef196f5

Please sign in to comment.