# my_test.py
from prepro import create_tokenizer_from_hub_module
from hparame import Hparame

hparame = Hparame()
parser = hparame.parser
hp = parser.parse_args()
set_training = True


def eval_decoded_texts(texts, predicted_labels, sentiment_ids, tokenizer):
    decoded_texts = []
    for i, text in enumerate(texts):
        if isinstance(text, bytes):
            text = text.decode("utf-8")
        # sentiment "neutral" or fewer than two words: keep the text unchanged
        if sentiment_ids[i] == 0 or len(text.split()) < 2:
            decoded_texts.append(text)
        else:
            text_list = text.lower().split()
            text_token = tokenizer.tokenize(text)
            segment_id = []
            # record the segment id: segment_id[k] is the index of the
            # whitespace-split word that WordPiece token k belongs to
            j_text = 0
            j_token = 0
            while j_text < len(text_list) and j_token < len(text_token):
                _j_token = j_token + 1
                text_a = "".join(tokenizer.tokenize(text_list[j_text])).replace("##", "")
                while True:
                    segment_id.append(j_text)
                    if "".join(text_token[j_token:_j_token]).replace("##", "") == text_a:
                        j_token = _j_token
                        break
                    _j_token += 1
                j_text += 1
            assert len(segment_id) == len(text_token)
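            # To illustrate the alignment (a toy example, assuming a
            # BERT-style WordPiece tokenizer; the real one comes from
            # create_tokenizer_from_hub_module): for the text "unhappy day",
            # a tokenizer emitting ["un", "##happy", "day"] would give
            # segment_id == [0, 0, 1], so each sub-token maps back to the
            # whitespace-split word it came from.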
            # get selected_text: the label row has an extra slot at the front
            # (presumably [CLS]) and trailing slots past the token length
            # ([SEP]/padding), so trim it to exactly one label per token
            selected_text = []
            predicted_label_id = predicted_labels[i]
            predicted_label_id.pop(0)
            for _ in range(len(predicted_label_id) - len(text_token)):
                predicted_label_id.pop()
            max_len = len(predicted_label_id)
            assert len(text_token) == max_len
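            # For instance, with 3 tokens and a 6-slot label row
            # [l_cls, l1, l2, l3, l_sep, l_pad], pop(0) drops l_cls and the
            # loop then pops l_pad and l_sep, leaving [l1, l2, l3].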
            j = 0
            while j < max_len:
                if predicted_label_id[j] == 1:
                    if j == max_len - 1:
                        # a positive label on the very last token is skipped
                        j += 1
                    else:
                        # take the whole word this token belongs to, then
                        # advance j to the first token of the next word
                        a_selected_text = text_list[segment_id[j]]
                        selected_text.append(a_selected_text)
                        for new_j in range(j + 1, len(segment_id)):
                            if segment_id[j] != segment_id[new_j]:
                                j = new_j
                                break
                            elif new_j == len(segment_id) - 1:
                                j = new_j
                else:
                    j += 1
            decoded_texts.append(" ".join(selected_text))
    return decoded_texts


if __name__ == "__main__":
    texts = ["Well that's disappointing to hear.", "http://twitpic.com/4w75p - I like it!!"]
    predicted_labels = [[0, 0, 0, 1, 0, 0, 0, 1, 0, 1] + [0] * 118, [0] * 6 + [1] * 2 + [0] * 17]
    sentiment_ids = [1, 0]
    tokenizer = create_tokenizer_from_hub_module(hp)
    output = eval_decoded_texts(texts, predicted_labels, sentiment_ids, tokenizer)
    print(output)
    # text_token = tokenizer.tokenize(texts[0])
    # print(text_token[:3])
    # print(texts[0].lower().split()[1])
    # print("".join(text_token[1:3]).replace("##", "") == texts[0].lower().split()[1].decode("utf-8", "ignore"))