-
Notifications
You must be signed in to change notification settings - Fork 475
/
Copy pathprepare_pretrained_model.py
162 lines (136 loc) · 7.36 KB
/
prepare_pretrained_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
'''
This script prepares a pretrained model to be shared without exposing the data used for training.
'''
import glob
import os
import pickle
from pprint import pprint
import shutil
import utils
from entity_lstm import EntityLSTM
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from neuroner import utils_tf
from neuroner import neuromodel
def trim_dataset_pickle(input_dataset_filepath, output_dataset_filepath=None, delete_token_mappings=False):
    '''
    Remove the dataset and labels from dataset.pickle.

    If delete_token_mappings = True, then also replace token_to_index and
    index_to_token with mappings that contain only the UNK token.

    input_dataset_filepath: path of the pickled dataset object to trim.
    output_dataset_filepath: where to write the trimmed pickle. Defaults to
        'dataset_trimmed.pickle' in the same folder as the input file.
    delete_token_mappings: if True, also drop every token mapping except UNK.
    '''
    print("Trimming dataset.pickle..")
    if output_dataset_filepath is None:
        output_dataset_filepath = os.path.join(os.path.dirname(input_dataset_filepath),
                                               'dataset_trimmed.pickle')
    # NOTE: unpickling can execute arbitrary code; only use with trusted, locally produced files.
    with open(input_dataset_filepath, 'rb') as input_file:
        dataset = pickle.load(input_file)
    count = 0
    print("Keys removed:")
    # Attributes that hold the training data itself (tokens, labels, indices).
    keys_to_remove = ['character_indices', 'character_indices_padded', 'characters',
                      'label_indices', 'label_vector_indices', 'labels', 'token_indices',
                      'token_lengths', 'tokens', 'infrequent_token_indices', 'tokens_mapped_to_unk']
    for key in keys_to_remove:
        if key in dataset.__dict__:
            del dataset.__dict__[key]
            print('\t' + key)
            count += 1
    if delete_token_mappings:
        # Keep only the UNK entry so no vocabulary from the training data remains.
        unk_token = dataset.__dict__['UNK']
        unk_token_index = dataset.__dict__['UNK_TOKEN_INDEX']
        dataset.__dict__['token_to_index'] = {unk_token: unk_token_index}
        dataset.__dict__['index_to_token'] = {unk_token_index: unk_token}
    print("Number of keys removed: {0}".format(count))
    pprint(dataset.__dict__)
    # Use a context manager so the pickle is fully flushed and the handle closed.
    with open(output_dataset_filepath, 'wb') as output_file:
        pickle.dump(dataset, output_file)
    print("Done!")
def trim_model_checkpoint(parameters_filepath, dataset_filepath, input_checkpoint_filepath,
                          output_checkpoint_filepath):
    '''
    Remove all token embeddings except UNK.

    Builds an EntityLSTM from parameters_filepath and dataset_filepath, and
    restores input_checkpoint_filepath into it. It then shrinks
    token_embedding_weights to a single row holding the UNK embedding and
    saves the session to output_checkpoint_filepath. Finally it rewrites
    dataset_filepath in place with vocabulary_size set to 1.

    parameters_filepath: path to the parameters.ini used to build the model.
    dataset_filepath: path to dataset.pickle. It is read to build the model
        and overwritten afterwards.
    input_checkpoint_filepath: checkpoint to restore (tensor shapes must match the model).
    output_checkpoint_filepath: where the trimmed checkpoint is saved.
    '''
    parameters, _ = neuromodel.load_parameters(parameters_filepath=parameters_filepath)
    with open(dataset_filepath, 'rb') as dataset_file:
        dataset = pickle.load(dataset_file)
    model = EntityLSTM(dataset, parameters)
    with tf.Session() as sess:
        model_saver = tf.train.Saver()  # defaults to saving all variables
        # Restore the pretrained model.
        # Works only when the dimensions of tensor variables are matched.
        model_saver.restore(sess, input_checkpoint_filepath)
        # Keep a copy of the pretrained embeddings before shrinking the variable.
        token_embedding_weights = sess.run(model.token_embedding_weights)
        # Resize the token embedding variable to a single row.
        utils_tf.resize_tensor_variable(sess, model.token_embedding_weights,
                                        [1, parameters['token_embedding_dimension']])
        initial_weights = sess.run(model.token_embedding_weights)
        # NOTE(review): the resized variable has exactly one row, so this assumes
        # dataset.UNK_TOKEN_INDEX == 0 — confirm against the dataset code.
        initial_weights[dataset.UNK_TOKEN_INDEX] = token_embedding_weights[dataset.UNK_TOKEN_INDEX]
        sess.run(tf.assign(model.token_embedding_weights, initial_weights, validate_shape=False))
        token_embedding_weights = sess.run(model.token_embedding_weights)
        print("token_embedding_weights: {0}".format(token_embedding_weights))
        model_saver.save(sess, output_checkpoint_filepath)
    dataset.__dict__['vocabulary_size'] = 1
    # Close the handle deterministically so the rewritten pickle is fully flushed.
    with open(dataset_filepath, 'wb') as dataset_file:
        pickle.dump(dataset, dataset_file)
    pprint(dataset.__dict__)
def prepare_pretrained_model_for_restoring(output_folder_name, epoch_number,
                                           model_name, delete_token_mappings=False):
    '''
    Copy the dataset.pickle, parameters.ini, and model checkpoint files after
    removing the data used for training.

    The dataset and labels are deleted from dataset.pickle by default. The only
    information about the dataset that remains in the pretrained model
    is the list of tokens that appear in the dataset and the corresponding token
    embeddings learned from the dataset.

    If delete_token_mappings is set to True, the index_to_token and
    token_to_index mappings are also deleted from dataset.pickle, and the
    corresponding token embeddings are deleted from the model checkpoint
    files. In this case, the pretrained model does not contain
    any information about the dataset used for training the model.

    If you wish to share a pretrained model with delete_token_mappings = True,
    it is highly recommended to use some external pre-trained token
    embeddings and freeze them while training the model to obtain high performance.
    This can be done by specifying token_pretrained_embedding_filepath
    and setting freeze_token_embeddings = True in parameters.ini for training.

    Inputs are read from ./output/<output_folder_name>/model and outputs are
    written to ./trained_models/<model_name>.

    Raises FileNotFoundError if no checkpoint files exist for epoch_number.
    This check runs before anything is written.
    '''
    input_model_folder = os.path.join('.', 'output', output_folder_name, 'model')
    output_model_folder = os.path.join('.', 'trained_models', model_name)

    epoch_number_string = str(epoch_number).zfill(5)
    checkpoint_prefix = 'model_{0}.ckpt'.format(epoch_number_string)
    # Escape the folder so characters such as '[' in a run name are not treated as glob patterns.
    input_checkpoint_filepaths = glob.glob(
        os.path.join(glob.escape(input_model_folder), checkpoint_prefix + '*'))
    if not input_checkpoint_filepaths:
        # Fail before writing anything, rather than producing a model folder with no checkpoint.
        raise FileNotFoundError(
            "No checkpoint files matching '{0}*' found in {1}".format(
                checkpoint_prefix, input_model_folder))

    utils.create_folder_if_not_exists(output_model_folder)

    # trim and copy dataset.pickle
    input_dataset_filepath = os.path.join(input_model_folder, 'dataset.pickle')
    output_dataset_filepath = os.path.join(output_model_folder, 'dataset.pickle')
    trim_dataset_pickle(input_dataset_filepath, output_dataset_filepath,
                        delete_token_mappings=delete_token_mappings)

    # copy parameters.ini
    parameters_filepath = os.path.join(input_model_folder, 'parameters.ini')
    shutil.copy(parameters_filepath, output_model_folder)

    # (trim and) copy checkpoint files
    if delete_token_mappings:
        input_checkpoint_filepath = os.path.join(input_model_folder, checkpoint_prefix)
        output_checkpoint_filepath = os.path.join(output_model_folder, 'model.ckpt')
        trim_model_checkpoint(parameters_filepath, output_dataset_filepath,
                              input_checkpoint_filepath, output_checkpoint_filepath)
    else:
        for filepath in input_checkpoint_filepaths:
            # Replace only the leading 'model_<epoch>.ckpt' so the shared files are named 'model.ckpt*'.
            output_filename = 'model.ckpt' + os.path.basename(filepath)[len(checkpoint_prefix):]
            shutil.copyfile(filepath, os.path.join(output_model_folder, output_filename))
def check_contents_of_dataset_and_model_checkpoint(model_folder):
    '''
    Check the contents of dataset.pickle and model.ckpt.

    Prints every attribute of the pickled dataset object and then its
    attribute names. It then prints the checkpoint's tensors: first all of
    them, then only 'token_embedding/token_embedding_weights'.

    model_folder: folder containing dataset.pickle and model.ckpt to be checked.
    '''
    dataset_filepath = os.path.join(model_folder, 'dataset.pickle')
    with open(dataset_filepath, 'rb') as dataset_file:
        dataset = pickle.load(dataset_file)
    pprint(dataset.__dict__)
    pprint(list(dataset.__dict__.keys()))
    checkpoint_filepath = os.path.join(model_folder, 'model.ckpt')
    # No tf.Session is opened here: neither call below uses one.
    print_tensors_in_checkpoint_file(checkpoint_filepath,
                                     tensor_name='token_embedding/token_embedding_weights',
                                     all_tensors=True)
    print_tensors_in_checkpoint_file(checkpoint_filepath,
                                     tensor_name='token_embedding/token_embedding_weights',
                                     all_tensors=False)
if __name__ == '__main__':
    # Training run (under ./output) and epoch whose checkpoint should be packaged.
    output_folder_name = 'en_2017-05-05_08-58-32-633799'
    epoch_number = 30
    # Name of the folder created under ./trained_models.
    model_name = 'conll_2003_en'
    # Set to True to also strip the token vocabulary and its embeddings.
    delete_token_mappings = False
    prepare_pretrained_model_for_restoring(output_folder_name=output_folder_name,
                                           epoch_number=epoch_number,
                                           model_name=model_name,
                                           delete_token_mappings=delete_token_mappings)
    # To inspect an already-packaged model instead, use:
    # model_name = 'mimic_glove_spacy_iobes'
    # model_folder = os.path.join('..', 'trained_models', model_name)
    # check_contents_of_dataset_and_model_checkpoint(model_folder)