-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathTargetDiscriminator.py
55 lines (39 loc) · 1.67 KB
/
TargetDiscriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import torch
from ImageTools import *
from GAN import Discriminator
class TargetDiscriminator:
    """Score images with a pre-trained GAN ``Discriminator``.

    The constructor loads the discriminator weights from a state-dict file and
    puts the model in eval mode. Two scoring entry points are exposed:

    * :meth:`predict` scores a single in-memory image.
    * :meth:`predict_batch` scores a list of image files.

    Every input is resized to 512x512 before it is fed to the model.

    NOTE(review): callers treat the scores as the probability that an image is
    a "real target". This presumes ``Discriminator`` ends in a sigmoid;
    confirm in GAN.py.
    """

    def __init__(self, model_path, device='cuda'):
        """Load discriminator weights from *model_path* onto *device*.

        Args:
            model_path: Path to a saved ``state_dict`` for ``Discriminator``.
            device: Torch device (string or ``torch.device``) that the model
                and all inputs are moved to. The default is ``'cuda'``, so a
                CUDA device is required unless another device is passed.
        """
        self.device = device
        self.model = Discriminator().to(self.device)
        # SECURITY: torch.load unpickles the file, so only load checkpoints
        # from trusted sources. On torch>=1.13, consider weights_only=True,
        # since only a state dict is needed here.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        # Inference mode. This affects layers such as dropout or batch norm,
        # if the Discriminator has any.
        self.model.eval()
        self.transform = transforms.Resize((512, 512))

    def predict(self, image_pcb):
        """Return the discriminator score for a single in-memory image.

        Args:
            image_pcb: Image accepted by ``convert_image_to_tensor``
                (presumably from ``ImageTools``; TODO confirm the accepted
                type).

        Returns:
            The model output as a Python float (via ``Tensor.item()``). The
            model must therefore produce exactly one value per image.
        """
        image = convert_image_to_tensor(image_pcb)
        # Resize, then add a leading batch dimension of 1 for the model.
        image = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model(image).item()

    def predict_batch(self, image_path_list, *, batch_size=None):
        """Return discriminator outputs for a list of image files.

        Args:
            image_path_list: Iterable of image file paths, each loaded with
                ``load_image_to_tensor``.
            batch_size: Maximum number of images sent to the model per forward
                pass. ``None`` (the default) scores all images in a single
                pass, as before. Use a smaller value to bound device memory
                for long lists.

        Returns:
            A NumPy array of model outputs, one entry/row per path, in input
            order. The exact shape depends on the Discriminator's output
            (presumably ``(N,)`` or ``(N, 1)``; TODO confirm). An empty input
            returns an empty array of shape ``(0,)``.

        Raises:
            ValueError: If *batch_size* is not ``None`` and is not positive.
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')

        paths = list(image_path_list)
        # torch.stack rejects an empty list, so handle "no images" explicitly.
        if not paths:
            return torch.empty(0).numpy()

        step = len(paths) if batch_size is None else batch_size
        outputs = []
        with torch.no_grad():
            for start in range(0, len(paths), step):
                chunk = paths[start:start + step]
                images = torch.stack(
                    [self.transform(load_image_to_tensor(path)) for path in chunk]
                ).to(self.device)
                outputs.append(self.model(images).cpu())

        # A single chunk (the default) is returned as-is, exactly as before.
        # Otherwise, join the chunks along the batch axis.
        if len(outputs) == 1:
            return outputs[0].numpy()
        return torch.cat(outputs).numpy()
if __name__ == '__main__':
    discriminator = TargetDiscriminator('saved_model/Discriminator_trained.pth')

    # Collect the test images. The extension match is case-insensitive so that
    # files like "x.JPG" are not skipped. The list is sorted because
    # os.listdir order is arbitrary, and sorting keeps the output order
    # deterministic.
    image_directory = 'dis_test/'
    image_paths = sorted(
        os.path.join(image_directory, filename)
        for filename in os.listdir(image_directory)
        if filename.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    if not image_paths:
        print(f'No images found in {image_directory}')
    else:
        # Score all images in one batch.
        probs = discriminator.predict_batch(image_paths)
        for path, prob in zip(image_paths, probs):
            image_name = os.path.basename(path)
            # .item() works whether each row is a scalar or a 1-element array.
            # Formatting a 1-element ndarray with :.4f raises TypeError.
            print(f'The probability of {image_name} being a real target is: {prob.item():.4f}')