Qualitative Results #7
Hello, author. I want to refer to the code for the "Qualitative Results" section, but I haven't found it. Has the code for that section not been made public?

Comments
Yes, it has not been included. I found one copy, but I am not sure whether it is the one that was actually used (it has been a long time, about two years), so you may have to make some small updates if bugs appear.

```python
'''
For qualitative analysis and other analysis
'''
import argparse
import copy
import json
import os
from itertools import product

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from scipy.stats import hmean
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import cv2

from utils import *
from parameters import parser
from dataset import CompositionDataset
from model.model_factory import get_model
from test import predict_logits

cudnn.benchmark = True
device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":
    config = parser.parse_args()
    if config.yml_path:
        load_args(config.yml_path, config)
    # set the seed value

    print("evaluation details")
    print("----")
    print(f"dataset: {config.dataset}")

    dataset_path = config.dataset_path

    with torch.no_grad():
        print('loading test dataset')
        test_dataset = CompositionDataset(dataset_path,
                                          phase='test',
                                          split='compositional-split-natural',
                                          open_world=False)  # change this if open-world

        allattrs = test_dataset.attrs
        allobj = test_dataset.objs

        # normalize attribute / object names
        classes = [cla.replace(".", " ").lower() for cla in allobj]
        attributes = [attr.replace(".", " ").lower() for attr in allattrs]
        offset = len(attributes)

        # build the model and load the trained weights
        model = get_model(config, attributes=attributes, classes=classes, offset=offset).cuda()
        model.load_state_dict(torch.load(config.load_model))

        # logits over all (attribute, object) pairs for every test image
        all_logits, all_attr_gt, all_obj_gt, all_pair_gt, loss_avg = predict_logits(
            model, test_dataset, config)

        pairs_dataset = test_dataset.pairs
        print('all_logits.shape:', all_logits.shape)

        # top-1 predicted pair for every test image
        prediction_idx = torch.max(all_logits, dim=-1)[1]
        prediction_list = [pairs_dataset[idx.item()] for idx in prediction_idx]

        # save the top-1 predictions and the raw logits
        write_json(os.path.join(config.save_path, "top1_test_prediction.json"), prediction_list)
        torch.save(all_logits, os.path.join(config.save_path, "test_prediction_logits.pt"))
```
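For reference, here is a minimal sketch of how the two files saved at the end of the script could be loaded to pick out qualitative examples. It assumes the JSON holds one (attribute, object) pair per test image in dataset order, that the `.pt` file is the `[num_images, num_pairs]` logit tensor saved above, and that `save_path` below is a placeholder for whatever `config.save_path` was used.

```python
import json
import os

import torch

save_path = "logs/experiment"  # placeholder; use the same save_path as the script above

# Top-1 (attribute, object) pair per test image, in dataset order (assumed format).
with open(os.path.join(save_path, "top1_test_prediction.json")) as f:
    top1_pairs = json.load(f)  # e.g. [["red", "apple"], ...]

# Raw logits over all pairs, shape [num_images, num_pairs].
all_logits = torch.load(os.path.join(save_path, "test_prediction_logits.pt"))

# Softmax confidence of the top-1 pair for each image.
probs = torch.softmax(all_logits, dim=-1)
top1_conf = probs.max(dim=-1).values

# Print the five most and least confident predictions for a quick qualitative look.
order = torch.argsort(top1_conf, descending=True)
for name, indices in [("most confident", order[:5]), ("least confident", order[-5:])]:
    print(f"--- {name} ---")
    for i in indices.tolist():
        attr, obj = top1_pairs[i]
        print(f"image {i}: {attr} {obj} (p={top1_conf[i].item():.3f})")
```

Sorting by confidence is only one convenient way to surface striking successes and failures; the printed image indices can then be matched back to the corresponding test images for visualization.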