-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathget_test_golden.py
80 lines (65 loc) · 2.08 KB
/
get_test_golden.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import json
import ijson
from tqdm import tqdm
def _load_articles(path):
    """Stream the article dump at *path* and index its records by pmid.

    Reads every element of the top-level ``articles`` array with ijson and
    returns a dict mapping the stripped pmid string to a record holding the
    ``pmid``, ``title``, ``abstract``, ``meshMajor`` and ``meshId`` fields.

    Records whose pmid, title or abstract is not a string (no ``.strip()``)
    are skipped and their pmid is printed. When a pmid occurs more than
    once, the first occurrence wins.
    """
    articles = {}
    # Context manager so the (potentially very large) input is closed even
    # if parsing fails part-way through.
    with open(path, encoding="utf8") as handle:
        for obj in tqdm(ijson.items(handle, 'articles.item')):
            try:
                record = {
                    'pmid': obj["pmid"].strip(),
                    'title': obj['title'].strip(),
                    'abstract': obj["abstractText"].strip(),
                    'meshMajor': obj["meshMajor"],
                    'meshId': obj['meshId'],
                }
            except AttributeError:
                # One of the text fields was not a string. Report the pmid
                # via str() so the report itself cannot fail when the pmid
                # is the offending non-string value.
                print(str(obj["pmid"]).strip())
                continue
            # setdefault keeps the first record seen for a duplicated pmid.
            articles.setdefault(record['pmid'], record)
    return articles


def _load_test_pmids(path):
    """Return the pmids of the ``documents`` array at *path*, in file order.

    Each pmid is converted with ``str()`` and stripped, so numeric ids are
    accepted as well as string ids.
    """
    with open(path, encoding="utf8") as handle:
        return [
            str(obj["pmid"]).strip()
            for obj in tqdm(ijson.items(handle, 'documents.item'))
        ]


def main(argv=None):
    """Build a labelled test set by joining test pmids against all articles.

    Command-line options:
        --allMesh         JSON file with an ``articles`` array of records.
        --test_set        JSON file with a ``documents`` array of pmids.
        --completed_test  Output path for the labelled test set.

    Args:
        argv: Optional list of argument strings; when ``None`` argparse
            reads ``sys.argv[1:]``.

    Writes ``{"documents": [...]}`` to ``--completed_test`` with one entry
    per test pmid found in ``--allMesh``, in test-set order. Test pmids that
    are not found are printed and left out.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--allMesh')
    parser.add_argument('--test_set')
    parser.add_argument('--completed_test')
    args = parser.parse_args(argv)

    # Load all articles in file
    print('Start loading training data')
    articles = _load_articles(args.allMesh)

    # Load test set ids
    print('Start loading test data')
    test_pmids = _load_test_pmids(args.test_set)

    # Create new test set with labels
    print('Create new test set with labels')
    dataset = []
    for test_id in tqdm(test_pmids):
        # Dict lookup instead of scanning a list for every test id.
        record = articles.get(test_id)
        if record is None:
            print('Not in the list: ', test_id)
            continue
        # Copy so a pmid repeated in the test set yields independent entries.
        dataset.append(dict(record))

    print('write to files')
    with open(args.completed_test, "w", encoding="utf8") as outfile:
        json.dump({'documents': dataset}, outfile, indent=4)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()