-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_weights.py
39 lines (32 loc) · 1.16 KB
/
load_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
    """Build a ResNet-34, load pretrained weights, and replace its fc layer.

    The state dict stored at ``./resnet34-pre.pth`` is loaded into a freshly
    constructed ``resnet34()``; afterwards ``net.fc`` is swapped for a new
    ``nn.Linear`` with 5 outputs. A second way of doing the same job
    ("option2") is kept as a commented-out reference at the end.

    Raises:
        FileNotFoundError: if ``./resnet34-pre.pth`` does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    # Raise explicitly instead of using `assert`: assertions are removed when
    # Python runs with -O, which would silently skip this check.
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not exist.".format(model_weight_path))

    # option1: load every pretrained tensor (including the original fc),
    # then replace the fc layer.
    net = resnet34()
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (recent PyTorch releases also offer
    # `weights_only=True` to restrict what may be unpickled).
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

    # option2: build the network with 5 classes up front, drop the "fc"
    # entries from the pretrained state dict, and load the rest non-strictly.
    # net = resnet34(num_classes=5)
    # pre_weights = torch.load(model_weight_path, map_location=device)
    # del_key = []
    # for key, _ in pre_weights.items():
    #     if "fc" in key:
    #         del_key.append(key)
    #
    # for key in del_key:
    #     del pre_weights[key]
    #
    # missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    # print("[missing_keys]:", *missing_keys, sep="\n")
    # print("[unexpected_keys]:", *unexpected_keys, sep="\n")
# Run the weight-loading routine only when this file is executed as a script.
if __name__ == '__main__':
    main()