-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsent_eval.py
130 lines (110 loc) · 3.39 KB
/
sent_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Copyright (C) eqtgroup.com Ltd 2021
https://github.com/EQTPartners/pause
License: MIT, https://github.com/EQTPartners/pause/LICENSE.md
"""
import os
import json
import argparse
import logging
import tensorflow as tf
import senteval
from tensorflow.core.example import example_pb2, feature_pb2
# The following import is mandatory
import tensorflow_text as text
import numpy as np
def prepare(params: senteval.utils.dotdict, samples: list) -> None:
    """No-op hook required by the SentEval engine interface.

    SentEval calls this once per task with all samples before evaluation;
    this model needs no per-task preparation, so nothing happens here.

    Args:
        params (senteval.utils.dotdict): SentEval parameters (unused).
        samples (list): All task samples (unused).
    """
    return None
def batcher(params: senteval.utils.dotdict, batch: list) -> np.ndarray:
    """Embed a batch of tokenized sentences with the loaded model.

    Args:
        params (senteval.utils.dotdict): SentEval parameters; must expose
            the serving signature under the "predict_fn" key.
        batch (list): Sentences, each given as a list of word tokens.

    Returns:
        np.ndarray: One embedding vector per input sentence.
    """
    # Join tokens back into plain text; an empty sentence becomes a lone
    # period so the model always receives non-empty input.
    sentences = []
    for tokens in batch:
        sentences.append(" ".join(tokens) if tokens != [] else ".")
    serialized = tf.constant([make_example(sentence) for sentence in sentences])
    outputs = params["predict_fn"](examples=serialized)
    return outputs["output"].numpy()
def make_example(text: str) -> bytes:
    """Serialize a plain string into a ``tf.train.Example`` proto.

    The sentence is stored UTF-8 encoded under the "sentence" feature key,
    matching the serving signature's expected input format.

    NOTE(review): the parameter name shadows the module-level
    ``tensorflow_text as text`` import; that import exists only for its
    registration side effects, so the shadowing is harmless, but the name
    is kept for backward compatibility with keyword callers.

    Args:
        text (str): The input sentence.

    Returns:
        bytes: The serialized ``Example`` proto.
        (The original annotation said ``tf.Tensor``, but
        ``SerializeToString()`` returns ``bytes``.)
    """
    ex = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "sentence": feature_pb2.Feature(
                    bytes_list=feature_pb2.BytesList(value=[text.encode()])
                ),
            }
        )
    )
    return ex.SerializeToString()
if __name__ == "__main__":
    # CLI entry point: evaluate a trained embedding model on the SentEval
    # transfer tasks and export the results as JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to SentEval data",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The trained embed model",
    )
    parser.add_argument(
        "--model_location",
        default="local",
        type=str,
        help="The model location: gcs or local",
    )
    args = parser.parse_args()

    # Resolve the SavedModel directory from the location flag.
    if "gcs" in str(args.model_location).lower():
        model_dir = (
            f"gs://motherbrain-pause/model/{args.model}/embed/serving_model_dir/"
        )
    else:
        model_dir = f"./artifacts/model/embed/{args.model}/"
    loaded_model = tf.saved_model.load(model_dir)
    predict_fn = loaded_model.signatures["serving_default"]

    # SentEval configuration; "predict_fn" is read back by batcher().
    params_senteval = {
        "task_path": args.data_path,
        "usepytorch": True,
        "kfold": 10,
        "classifier": {
            "nhid": 0,
            "optim": "rmsprop",
            "batch_size": 128,
            "tenacity": 3,
            "epoch_size": 2,
        },
        "predict_fn": predict_fn,
    }
    logging.basicConfig(format="%(asctime)s : %(message)s", level=logging.DEBUG)
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        "SST2",
        "MR",
        "CR",
        "MPQA",
        "SUBJ",
        "TREC",
        "MRPC",
    ]
    results = se.eval(transfer_tasks)
    print(args.model, results)

    test_result_path = "./artifacts/test"
    # exist_ok=True replaces the original check-then-create pattern, which
    # had a TOCTOU race between os.path.exists() and os.makedirs().
    os.makedirs(test_result_path, exist_ok=True)
    senteval_out_file = "{}/sent_eval_{}.txt".format(test_result_path, args.model)
    # Mode "w" (was "w+"): the file is only written, never read back.
    with open(senteval_out_file, "w") as out_file:
        out_file.write(json.dumps(results))
    print("The SentEval test result is exported to {}".format(senteval_out_file))