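"""Evaluate a trained model on a test set of version cliques by simulating the
version retrieval task, and write the resulting metrics to a CSV file.

Example invocation (placeholder paths, adjust to your setup):
    python evaluate.py path/to/config.yaml path/to/test_cliques.json -s MCSS
"""
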
import os
import csv
import time
import argparse

import yaml
import torch
from torch.utils.data import DataLoader

from model.nets import CQTNet
from model.dataset import TestDataset
from model.utils import load_model
from utilities.utils import format_time
from utilities.metrics import calculate_metrics

# My Linux complains without the following line.
# It is required for the DataLoader to run with num_workers > 0.
torch.multiprocessing.set_sharing_strategy("file_system")


@torch.no_grad()
def evaluate(
    model: CQTNet,
    loader: DataLoader,
    similarity_search: str,
    chunk_size: int,
    noise_works: bool,
    amp: bool,
    device: torch.device,
) -> dict:
"""Evaluate the model by simulating the retrieval task. Compute the embeddings
of all versions and calculate the pairwise distances. Calculate the mean average
precision of the retrieval task. Metric calculations are done on the cpu but you
can choose the device for the model. Since we normalize the embeddings, MCSS is
equivalent to NNS. Please refer to the argparse arguments for more information.
Parameters:
-----------
model : CQTNet
Model to evaluate
loader : torch.utils.data.DataLoader
DataLoader containing the test set cliques
similarity_search: str
Similarity search function. MIPS, NNS, or MCSS.
chunk_size : int
Chunk size to use during metrics calculation.
noise_works : bool
Flag to indicate if the dataset contains noise works.
amp : bool
Flag to indicate if Automatic Mixed Precision should be used.
device : torch.device
Device to use for inference and metric calculation.
Returns:
--------
metrics : dict
Dictionary containing the evaluation metrics. See utilities.metrics.calculate_metrics
"""
    t0 = time.monotonic()
    model.eval()
    N = len(loader)
    # Probe a single item to determine the embedding dimensionality.
    emb_dim = model(
        loader.dataset[0][0].unsqueeze(0).unsqueeze(1).to(device)
    ).shape[1]
    # Preallocate tensors to avoid https://github.com/pytorch/pytorch/issues/13246
    embeddings = torch.zeros((N, emb_dim), dtype=torch.float32, device=device)
    labels = torch.zeros(N, dtype=torch.int32, device=device)
    print("Extracting embeddings...")
    for idx, (feature, label) in enumerate(loader):
        assert feature.shape[0] == 1, "Batch size must be 1 for inference."
        feature = feature.unsqueeze(1).to(device)  # (1,F,T) -> (1,1,F,T)
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=amp):
            embedding = model(feature)
        embeddings[idx : idx + 1] = embedding
        labels[idx : idx + 1] = label.to(device)
        # Print progress roughly every 10%; max(1, ...) guards small test sets.
        if (idx + 1) % max(1, N // 10) == 0 or idx == N - 1:
            print(f"[{idx + 1:>{len(str(N))}}/{N}]")
    print(f"Extraction time: {format_time(time.monotonic() - t0)}")
    # If there are no noise works, remove the cliques with single versions;
    # these may appear due to the feature extraction process.
    if not noise_works:
        # Count each label's occurrences
        unique_labels, counts = torch.unique(labels, return_counts=True)
        # Keep only the labels that occur more than once
        valid_labels = unique_labels[counts > 1]
        # Create a mask for indices whose labels appear more than once
        keep_mask = torch.isin(labels, valid_labels)
        if keep_mask.sum() < len(labels):
            print("Removing single version cliques...")
            embeddings = embeddings[keep_mask]
            labels = labels[keep_mask]
print("Calculating metrics...")
t0 = time.monotonic()
metrics = calculate_metrics(
embeddings,
labels,
similarity_search=similarity_search,
noise_works=noise_works,
chunk_size=chunk_size,
device=device,
)
print(f"Calculation time: {format_time(time.monotonic() - t0)}")
return metrics
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "config_path",
        type=str,
        help="""Path to the configuration file of the trained model.
        The config will be used to find the model weights.""",
    )
    parser.add_argument(
        "test_cliques",
        type=str,
        help="""Path to the test cliques.json file.
        Can be SHS100K, Da-TACOS, or DiscogsVI.""",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default=None,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--similarity-search",
        "-s",
        type=str,
        default=None,
        choices=["MIPS", "MCSS", "NNS"],
        help="""Similarity search function to use for the evaluation.
        MIPS: Maximum Inner Product Search,
        MCSS: Maximum Cosine Similarity Search,
        NNS: Nearest Neighbour Search.""",
    )
    parser.add_argument(
        "--features-dir",
        "-f",
        type=str,
        default=None,
        help="""Path to the features directory.
        Optional; by default uses the path in the config file.""",
    )
    parser.add_argument(
        "--chunk-size",
        "-b",
        type=int,
        default=1024,
        help="Chunk size to use during the metric calculation.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=10,
        help="Number of workers to use in the DataLoader.",
    )
    parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="""Flag to disable the GPU. If not provided,
        the GPU will be used if available.""",
    )
    parser.add_argument(
        "--disable-amp",
        action="store_true",
        help="""Flag to disable Automatic Mixed Precision for inference.
        If not provided, AMP usage will depend on the model config file.""",
    )
    args = parser.parse_args()
    with open(args.config_path) as f:
        config = yaml.safe_load(f)
    print("\033[36m\nExperiment Configuration:\033[0m")
    print(
        "\033[36m" + yaml.dump(config, indent=4, width=120, sort_keys=False) + "\033[0m"
    )
    if args.features_dir is None:
        print("\033[31mFeatures directory NOT provided.\033[0m")
        args.features_dir = config["TRAIN"]["FEATURES_DIR"]
        print(f"\033[31mFeatures directory: {args.features_dir}\033[0m\n")
    # Evaluate the model in an Information Retrieval setting
    eval_dataset = TestDataset(
        args.test_cliques,
        args.features_dir,
        mean_downsample_factor=config["MODEL"]["DOWNSAMPLE_FACTOR"],
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )
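    # batch_size must stay 1: evaluate() asserts it, presumably because the
    # features have variable lengths along the time axis and cannot be batched.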
    if args.no_gpu:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"\033[31mDevice: {device}\033[0m\n")
    if args.disable_amp:
        config["TRAIN"]["AUTOMATIC_MIXED_PRECISION"] = False
    model = load_model(config, device, mode="infer")
    if args.output_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        args.output_dir = os.path.join(
            script_dir, "logs", "evaluation", config["MODEL"]["NAME"]
        )
        if "best_epoch" in config["MODEL"]["CHECKPOINT_PATH"]:
            args.output_dir = os.path.join(args.output_dir, "best_epoch")
        elif "last_epoch" in config["MODEL"]["CHECKPOINT_PATH"]:
            args.output_dir = os.path.join(args.output_dir, "last_epoch")
        if eval_dataset.discogs_vi:
            args.output_dir = os.path.join(args.output_dir, "DiscogsVI")
        elif eval_dataset.datacos:
            args.output_dir = os.path.join(args.output_dir, "Da-TACOS")
        elif eval_dataset.shs100k:
            args.output_dir = os.path.join(args.output_dir, "SHS100K")
        else:
            raise ValueError("Dataset not recognized.")
    print(f"\033[31mOutput directory: {args.output_dir}\033[0m\n")
    os.makedirs(args.output_dir, exist_ok=True)
    if args.similarity_search is None:
        args.similarity_search = config["MODEL"]["SIMILARITY_SEARCH"]
    print("Evaluating...")
    t0 = time.monotonic()
    metrics = evaluate(
        model,
        eval_loader,
        similarity_search=args.similarity_search,
        chunk_size=args.chunk_size,
        noise_works=eval_dataset.datacos,  # Only the Da-TACOS set contains noise works
        amp=config["TRAIN"]["AUTOMATIC_MIXED_PRECISION"],
        device=device,
    )
    print(f"Total time: {format_time(time.monotonic() - t0)}")
    eval_path = os.path.join(args.output_dir, "evaluation_metrics.csv")
    print(f"Saving the evaluation results in: {eval_path}")
    with open(eval_path, "w") as f:
        writer = csv.writer(f)
        writer.writerow(["metric", "value"])
        for metric, value in metrics.items():
            writer.writerow([metric, value])
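    # Resulting CSV layout, one metric per row (names come from calculate_metrics;
    # the value shown here is purely illustrative):
    #   metric,value
    #   MAP,0.85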
    print("Done!")