-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathrun_pytorch_server.py
103 lines (78 loc) · 2.86 KB
/
run_pytorch_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# encoding: utf-8
"""
@author: xyliao
@contact: [email protected]
"""
import ast
import io
import json

import flask
import torch
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms as T
from torchvision.models import resnet50
# Initialize our Flask application and the PyTorch model.
app = flask.Flask(__name__)
# Populated by load_model().
model = None
# Use the GPU only when CUDA is actually available, so the server also
# starts on CPU-only hosts instead of crashing on the first .cuda() call.
use_gpu = torch.cuda.is_available()
# The label file holds a Python literal mapping class index -> label name
# (presumably a dict keyed by int or a list; confirm against
# imagenet_class.txt). ast.literal_eval only accepts literals, unlike eval(),
# which would execute any code placed in the file.
with open('imagenet_class.txt', 'r', encoding='utf-8') as f:
    idx2label = ast.literal_eval(f.read())
def load_model():
    """Create the global classifier served by this app.

    Builds an ImageNet-pretrained ResNet-50 (swap in your own model here),
    switches it to inference mode and, when ``use_gpu`` is set, moves its
    parameters to the GPU. The result is stored in the module-level ``model``.
    """
    global model
    # nn.Module.eval() returns the module itself, so construction and the
    # switch to inference mode fit in one expression.
    model = resnet50(pretrained=True).eval()
    if use_gpu:
        # Module.cuda() moves the parameters in place.
        model.cuda()
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.

    :param image: original PIL image
    :param target_size: size passed to ``T.Resize``; a ``(height, width)``
        tuple resizes to exactly that size
    :return: preprocessed float tensor of shape ``(1, 3, H, W)``, moved to
        the GPU when ``use_gpu`` is set. Callers should run the model under
        ``torch.no_grad()``; the tensor itself carries no gradient settings.
    """
    # The model expects 3-channel input; convert grayscale/RGBA/palette images.
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # Resize the input image, then convert it to a (C, H, W) float tensor.
    image = T.Resize(target_size)(image)
    image = T.ToTensor()(image)
    # Normalize with the ImageNet per-channel mean/std expected by
    # torchvision's pretrained models.
    image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    # Add batch_size axis.
    image = image[None]
    if use_gpu:
        image = image.cuda()
    # Variable(..., volatile=True) has been a no-op since PyTorch 0.4 (it only
    # emits a warning), so the tensor is returned directly.
    return image
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the top-3 predictions as JSON.

    Expects a multipart POST carrying the image in the form field ``image``.

    :return: JSON ``{"success": bool, "predictions": [...]}``. Each prediction
        is ``{"label": str, "probability": float}``. ``predictions`` is only
        present when ``success`` is True. A missing or undecodable image
        yields ``{"success": false}``.
    """
    # Initialize the data dictionary that will be returned from the view.
    data = {"success": False}

    # The route only accepts POST, so no method check is needed here.
    # FileStorage is falsy when the upload has no filename.
    upload = flask.request.files.get("image")
    if not upload:
        return flask.jsonify(data)

    # Read and fully decode the image in PIL format. Image.open is lazy, so
    # load() forces decoding here. Pillow signals unreadable or truncated
    # payloads with OSError subclasses (e.g. UnidentifiedImageError); report
    # them as an unsuccessful prediction instead of an HTTP 500.
    try:
        image = Image.open(io.BytesIO(upload.read()))
        image.load()
    except OSError:
        return flask.jsonify(data)

    # When the app is served by a WSGI server, the __main__ block (and
    # therefore load_model()) never runs; load on first use instead.
    if model is None:
        load_model()

    # Preprocess the image and prepare it for classification.
    image = prepare_image(image, target_size=(224, 224))

    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        preds = F.softmax(model(image), dim=1)
    probs, indices = torch.topk(preds.cpu(), k=3, dim=1)

    # Index labels with a Python int: a 0-d tensor used as a dict key never
    # matches an int key, because tensors hash by identity.
    data['predictions'] = [
        {"label": idx2label[int(idx)], "probability": float(prob)}
        for prob, idx in zip(probs[0], indices[0])
    ]

    # Indicate that the request was a success.
    data["success"] = True
    return flask.jsonify(data)
if __name__ == '__main__':
    # Startup banner, printed before the (possibly slow) weight download.
    for banner_line in ("Loading PyTorch model and Flask starting server ...",
                        "Please wait until server has fully started"):
        print(banner_line)
    # Load the model up front so the first request doesn't pay for it.
    load_model()
    app.run()