-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrescore-trafo-lm.py
87 lines (76 loc) · 2.92 KB
/
rescore-trafo-lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
import os
import sys
import time
import tqdm
import torch
import logging
import speechbrain as sb
import itertools
from hyperpyyaml import load_hyperpyyaml
from types import SimpleNamespace
import pathlib
def setup(hparams, run_opts):
""" Kind of mimics what Brain does """
if "device" in run_opts:
device = run_opts["device"]
elif "device" in hparams:
device = hparams["device"]
else:
device = "cpu"
print("Device is:", device)
if "cuda" in device:
torch.cuda.set_device(int(device[-1]))
modules = torch.nn.ModuleDict(hparams["modules"]).to(device)
hparams = SimpleNamespace(**hparams)
if hasattr(hparams, "checkpointer"):
if hasattr(hparams, "test_max_key"):
ckpt = hparams.checkpointer.find_checkpoint(max_key=hparams.test_max_key)
elif hasattr(hparams, "test_min_key"):
ckpt = hparams.checkpointer.find_checkpoint(min_key=hparams.test_min_key)
else:
ckpt = hparams.checkpointer.find_checkpoint()
hparams.checkpointer.load_checkpoint(ckpt)
epoch = hparams.epoch_counter.current
print("Loaded checkpoint from epoch", epoch, "at path", ckpt.path)
modules.eval()
return modules, hparams, device
def count_lines(infile):
    """Return the number of lines in the text file at ``infile``."""
    with open(infile) as handle:
        # Stream through the file so it is never held in memory at once.
        return sum(1 for _ in handle)
def text_io(infile):
    """Yield ``(uttid, text)`` pairs from a whitespace-separated text file.

    Each non-blank line is split on whitespace: the first field is the
    utterance id and the remaining fields form the (possibly empty) list
    ``text``. Blank or whitespace-only lines are skipped, since they carry no
    utterance id and would otherwise make the unpacking fail.

    Arguments
    ---------
    infile : str or path-like
        Path of the file to read.

    Yields
    ------
    tuple
        ``(uttid, text)`` with ``uttid`` a str and ``text`` a list of str.
    """
    with open(infile) as fin:
        for line in fin:
            fields = line.split()
            if not fields:
                continue
            uttid, *text = fields
            yield uttid, text
def run_test(modules, hparams, device):
    """Score each utterance of the test file and write the costs to disk.

    Reads ``hparams.testfile`` line by line (an utterance id followed by
    tokenizer pieces), wraps the piece ids in ``hparams.bos_index`` /
    ``hparams.eos_index``, runs ``hparams.model`` on the sequence without its
    last token and sums the NLL of the sequence without its first token.
    One ``<uttid> <cost>`` line per utterance is written to
    ``hparams.test_out``.

    Arguments
    ---------
    modules : torch.nn.ModuleDict
        Accepted for symmetry with ``setup``; not used in this function.
    hparams : SimpleNamespace
        Must provide ``testfile``, ``test_out``, ``bos_index``, ``eos_index``,
        ``tokenizer``, ``model`` and ``log_softmax``.
    device : str
        Device the encoded token tensor is moved to.
    """
    source_path = pathlib.Path(hparams.testfile)
    total = count_lines(source_path)
    utterances = text_io(source_path)
    bos, eos = hparams.bos_index, hparams.eos_index
    with open(hparams.test_out, 'w') as sink, torch.no_grad():
        for uttid, pieces in tqdm.tqdm(utterances, total=total):
            piece_ids = list(map(hparams.tokenizer.piece_to_id, pieces))
            # Single-utterance "batch": add a leading dimension of size one.
            batch = torch.LongTensor([bos, *piece_ids, eos]).to(device).unsqueeze(0)
            inputs = batch[:, :-1]   # starts with BOS, drops EOS
            targets = batch[:, 1:]   # drops BOS, ends with EOS
            log_probs = hparams.log_softmax(hparams.model(inputs))
            # nll_loss wants the class dimension second, so swap dims 1 and 2.
            log_probs = log_probs.transpose(1, 2)
            total_nll = torch.nn.functional.nll_loss(
                log_probs, targets, reduction="sum"
            )
            print(uttid, total_nll.cpu().item(), file=sink)
if __name__ == "__main__":
    # Split the command line into the hyperparameter file, the run options
    # and the override string.
    yaml_path, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Build the hyperparameters, applying the command-line overrides.
    with open(yaml_path) as yaml_file:
        loaded_hparams = load_hyperpyyaml(yaml_file, overrides)

    # setup() returns (modules, hparams, device), exactly what run_test takes.
    run_test(*setup(loaded_hparams, run_opts))