-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
324 lines (261 loc) · 10.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# -*- coding: utf-8 -*-
# @Author : HuangGang
# @Email : [email protected]
# @Time : 2021/11/10 12:00
# @Function : 工程框架工具类
import json
import os
import datetime
import random
import inspect
import sys
import pkgutil
import pprint as pp
import numpy as np
import config
import torch
import torch.backends.cudnn as cudnn
from loguru import logger as loguru_logger
from datetime import date
from pathlib import Path
from importlib import import_module
def count_param(model):
    """Count the total number of scalar parameters in a model.

    Args:
        model: an instantiated ``torch.nn.Module``.

    Returns:
        int: total element count across all parameter tensors.
    """
    # Tensor.numel() is the idiomatic (and cheaper) equivalent of the
    # original param.view(-1).size()[0] for any parameter shape.
    return sum(param.numel() for param in model.parameters())
def fix_random_seed_as(random_seed=0):
    """Seed every RNG used by the project so results are reproducible.

    Args:
        random_seed (int, optional): seed value. Defaults to 0.
    """
    # Seed the Python, NumPy, CPU-torch and GPU-torch generators.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # Force cuDNN to use its deterministic convolution algorithms instead of
    # benchmarking for the fastest one (see https://zhuanlan.zhihu.com/p/73711222).
    cudnn.benchmark = False
    cudnn.deterministic = True
def setup_run(args):
    """Prepare a run: configure GPUs, create the experiment folder, and
    pretty-print the non-None run arguments.

    Args:
        args: run configuration; must support ``.items()`` and attribute access.

    Returns:
        str: path of the experiment data folder.
    """
    set_up_gpu(args)
    experiment_path = setup_experiment_folder(args)
    visible = {key: value for key, value in args.items() if value is not None}
    pp.pprint(visible, width=1)
    return experiment_path
def setup_experiment_folder(args):
    """Create (or locate) the folder that stores this experiment's artifacts.

    Args:
        args: run configuration; reads ``experiment_dir``,
            ``experiment_description``, ``run_mode`` and ``resume_training``.

    Returns:
        str: path of the experiment data folder.
    """
    root = os.path.join("experiments", args.experiment_dir)
    description = args.experiment_description
    if not os.path.exists(root):
        os.makedirs(root)
    experiment_path = os.path.join(root, description)
    if args.run_mode == "train":
        if args.resume_training:
            # Resuming: the folder must already exist; snapshot the resume config.
            assert os.path.exists(experiment_path), "[!] RESUME 模型文件夹路径不存在,检查「args.experiment_description」!"
            export_experiments_config_as_json(args, experiment_path, json_name="resume_config.json")
        else:
            # Fresh run: derive a unique dated folder, then record config + PID.
            experiment_path = get_name_of_experiment_path(root, description)
            os.mkdir(experiment_path)
            export_experiments_config_as_json(args, experiment_path)
            write_run_msg(experiment_path)
    elif args.run_mode == "analyse":
        assert os.path.exists(experiment_path), "[!] 未发现模型文件夹,检查「args.experiment_description」!"
    else:
        raise ValueError("[!]「args.run_mode」 模式错误! --可选:train、analyse")
    loguru_logger.info("文件夹位置为: " + os.path.abspath(experiment_path))
    return experiment_path
def write_run_msg(experiment_path):
    """Record the start time and PID of this run in ``<experiment_path>/pid.txt``.

    Args:
        experiment_path (str): folder in which to write the ``pid.txt`` file.
    """
    pid = str(os.getpid())
    start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Use a context manager so the handle is closed even if a write fails, and
    # pin UTF-8 so the Chinese labels don't depend on the platform's locale.
    with open(os.path.join(experiment_path, "pid.txt"), "w", encoding="utf-8") as f:
        f.write("该训练脚本的启动时间为 : " + start_time + "\n")
        f.write("该训练脚本的 PID 为 : " + pid)
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Build a unique experiment folder path of the form ``<desc>_<today>_<i>``.

    Args:
        experiment_dir (str): parent folder of all experiments.
        experiment_description (str): human-readable experiment name.

    Returns:
        str: a folder path that does not exist yet.
    """
    dated_base = os.path.join(experiment_dir, experiment_description + "_" + str(date.today()))
    suffix = _get_experiment_index(dated_base)
    return dated_base + "_" + str(suffix)
def _get_experiment_index(experiment_path):
idx = 0
while os.path.exists(experiment_path + "_" + str(idx)):
idx += 1
return idx
def export_experiments_config_as_json(args, experiment_path, json_name="config.json"):
    """Serialize the run arguments to ``<experiment_path>/<json_name>`` as JSON."""
    target = os.path.join(experiment_path, json_name)
    with open(target, "w") as outfile:
        json.dump(vars(args), outfile, indent=4)
def all_subclasses(cls):
    """Collect every direct and transitive subclass of ``cls``.

    Args:
        cls (type): the base class.

    Returns:
        set[type]: all subclasses reachable via ``__subclasses__``, recursively.
    """
    found = set()
    for child in cls.__subclasses__():
        found.add(child)
        found |= all_subclasses(child)
    return found
def import_all_subclasses(_file, _name, _class):
    """Import every submodule under ``_file`` and re-export each subclass of
    ``_class`` found there onto the package module named ``_name``.

    Args:
        _file (str): ``__file__`` of the package.
        _name (str): ``__name__`` of the package.
        _class (type): base class whose subclasses should be exported.
    """
    target_module = sys.modules[_name]
    for submodule in get_all_submodules(_file, _name):
        for attr_name in dir(submodule):
            candidate = getattr(submodule, attr_name)
            if inspect.isclass(candidate) and issubclass(candidate, _class):
                setattr(target_module, attr_name, candidate)
def get_all_submodules(_file, _name):
    """Recursively import and collect every submodule of a package.

    Args:
        _file (str): ``__file__`` of the package.
        _name (str): ``__name__`` of the package.

    Returns:
        list: the imported module objects, depth-first.
    """
    package_dir = os.path.dirname(_file)
    collected = []
    for _, mod_name, is_package in pkgutil.iter_modules([package_dir]):
        imported = import_module("." + mod_name, package=_name)
        collected.append(imported)
        if is_package:
            # Descend into sub-packages as well.
            collected.extend(get_all_submodules(imported.__file__, imported.__name__))
    return collected
def set_up_gpu(args):
    """Expose the requested GPUs via ``CUDA_VISIBLE_DEVICES`` and normalize
    ``args.device_idx`` into a list of ints, recording the GPU count.

    Args:
        args: run configuration; ``device_idx`` is an int or a "0,1"-style string.
    """
    visible_devices = str(args.device_idx)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    indices = visible_devices.split(",")
    args.num_gpu = len(indices)
    args.device_idx = [int(idx) for idx in indices]
def save_test_result(export_root, result):
    """Dump test results as JSON into ``<export_root>/test_result.txt``.

    Args:
        export_root (str): destination folder.
        result: JSON-serializable test results.
    """
    target = Path(export_root) / "test_result.txt"
    with target.open("w") as handle:
        json.dump(result, handle, indent=2)
def load_pretrained_weights(model, path):
    """Load pretrained weights from a checkpoint file into ``model``.

    Args:
        model (torch.nn.Module): model instance to load the weights into.
        path (str): path of the checkpoint file.
    """
    # map_location="cpu" makes the load robust on machines without the GPU the
    # checkpoint was saved from; load_state_dict then copies tensors onto the
    # model's own parameters regardless.
    chk_dict = torch.load(os.path.abspath(path), map_location="cpu")
    # Weights live under config.STATE_DICT_KEY in this project's checkpoints;
    # fall back to the legacy "state_dict" key for older files.
    if config.STATE_DICT_KEY in chk_dict:
        model_state_dict = chk_dict[config.STATE_DICT_KEY]
    else:
        model_state_dict = chk_dict["state_dict"]
    model.load_state_dict(model_state_dict)
def setup_to_resume_from_recent(args, path, model=None, optimizer=None, lr_scheduler=None):
    """Restore training state from the most recent checkpoint under ``path``.

    Args:
        args: run configuration; ``num_gpu`` selects DataParallel unwrapping.
        path (str): experiment folder containing the ``models`` subfolder.
        model (torch.nn.Module, optional): model to restore; skipped if None.
        optimizer (optional): optimizer to restore; skipped if None.
        lr_scheduler (optional): LR scheduler to restore; skipped if None.

    Returns:
        tuple: (epoch_start, accum_iter_start, model, optimizer, lr_scheduler).
    """
    chk_dict = torch.load(os.path.join(path, "models", config.RECENT_STATE_DICT_FILENAME))
    # BUG FIX: the original hard-coded model/optimizer/lr_scheduler to None and
    # then dereferenced them, which always raised AttributeError. They are now
    # optional parameters and are restored only when actually supplied.
    if model is not None:
        if args.num_gpu > 1:
            # Multi-GPU models are wrapped (e.g. DataParallel); unwrap via .module.
            model.module.load_state_dict(chk_dict[config.STATE_DICT_KEY])
        else:
            model.load_state_dict(chk_dict[config.STATE_DICT_KEY])
    if optimizer is not None and config.OPTIMIZER_STATE_DICT_KEY in chk_dict:
        optimizer.load_state_dict(chk_dict[config.OPTIMIZER_STATE_DICT_KEY])
    if lr_scheduler is not None and config.SCHEDULER_STATE_DICT_KEY in chk_dict:
        lr_scheduler.load_state_dict(chk_dict[config.SCHEDULER_STATE_DICT_KEY])
    epoch_start, accum_iter_start = chk_dict[config.STEPS_DICT_KEY]
    return epoch_start, accum_iter_start, model, optimizer, lr_scheduler
class AverageMeterSet(object):
    """A named collection of AverageMeter objects for tracking training metrics.

    Args:
        meters (dict, optional): pre-built name -> AverageMeter mapping.
            Defaults to a fresh empty dict.
    """

    def __init__(self, meters=None):
        self.meters = meters if meters else {}

    def __getitem__(self, key):
        # Unknown names yield a throwaway zero-valued meter; it is NOT stored,
        # so plain reads never create entries in self.meters.
        if key in self.meters:
            return self.meters[key]
        placeholder = AverageMeter(key)
        placeholder.update(0)
        return placeholder

    def keys(self):
        return self.meters.keys()

    def update(self, name, value, n=1):
        # Lazily create the meter on first update.
        if name not in self.meters:
            self.meters[name] = AverageMeter(name)
        self.meters[name].update(value, n)

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    # values() / averages() / sums() / counts() each return {name: meter.xxx}.
    def values(self, format_string="{}"):
        return {format_string.format(name): meter.val for name, meter in self.meters.items()}

    def averages(self, format_string="{}"):
        return {format_string.format(name): meter.avg for name, meter in self.meters.items()}

    def sums(self, format_string="{}"):
        return {format_string.format(name): meter.sum for name, meter in self.meters.items()}

    def counts(self, format_string="{}"):
        return {format_string.format(name): meter.count for name, meter in self.meters.items()}

    def get_meters(self, meter_type):
        dispatch = {"sum": self.sums, "avg": self.averages, "value": self.values, "count": self.counts}
        if meter_type not in dispatch:
            raise ValueError('[!] meter_type error, choise ["sum", "avg", "val", "count"]')
        return dispatch[meter_type]()
class AverageMeter(object):
    """Tracks a streaming metric: latest value, running sum, count and average.

    Args:
        name (str): metric name (e.g. loss or accuracy).
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        # Zero out every statistic.
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # val joins the running sum once; n samples join the count.
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count

    def get_data(self, meter_type):
        readings = {"sum": self.sum, "avg": self.avg, "value": self.val, "count": self.count}
        if meter_type not in readings:
            raise ValueError('[!] meter_type error, choise ["sum", "avg", "val", "count"]')
        return readings[meter_type]