-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjinus.py
46 lines (39 loc) · 1.87 KB
/
jinus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
from kivy.utils import platform
if platform == 'android':
    # pyjnius bridge: bind the Java classes that TensorFlowModel relies on
    # as module-level names. They are only defined when the Kivy platform
    # reports 'android'; elsewhere these names do not exist.
    from jnius import autoclass

    # Java standard library classes.
    File = autoclass('java.io.File')
    ByteBuffer = autoclass('java.nio.ByteBuffer')

    # TensorFlow Lite runtime classes ('$' selects the nested Options class).
    Interpreter = autoclass('org.tensorflow.lite.Interpreter')
    InterpreterOptions = autoclass('org.tensorflow.lite.Interpreter$Options')
    Tensor = autoclass('org.tensorflow.lite.Tensor')
    DataType = autoclass('org.tensorflow.lite.DataType')

    # TensorFlow Lite Support library buffer used to receive model output.
    TensorBuffer = autoclass(
        'org.tensorflow.lite.support.tensorbuffer.TensorBuffer'
    )
class TensorFlowModel:
    """Thin wrapper around the TensorFlow Lite Java ``Interpreter`` via pyjnius.

    Relies on the module-level Java bindings ``File``, ``Interpreter``,
    ``InterpreterOptions``, ``TensorBuffer`` and ``ByteBuffer``, which this
    file only defines when running on Android.

    Assumes the model has exactly one input tensor and one output tensor
    (index 0 for both).
    """

    def load(self, model_filename, num_threads=None):
        """Load a ``.tflite`` model from disk and allocate its tensors.

        Args:
            model_filename: Filesystem path of the model; wrapped in a
                ``java.io.File`` and handed to the interpreter.
            num_threads: Optional thread count passed to
                ``Interpreter.Options.setNumThreads``; when ``None`` the
                option is left untouched.
        """
        model_file = File(model_filename)
        options = InterpreterOptions()
        if num_threads is not None:
            options.setNumThreads(num_threads)
        self.interpreter = Interpreter(model_file, options)
        self.allocate_tensors()

    def allocate_tensors(self):
        """Allocate interpreter tensors and cache input/output metadata.

        Sets ``input_shape``, ``output_shape`` and ``output_type`` from
        input tensor 0 and output tensor 0.
        """
        self.interpreter.allocateTensors()
        self.input_shape = self.interpreter.getInputTensor(0).shape()
        # Fetch the output tensor once and read both properties from it.
        output_tensor = self.interpreter.getOutputTensor(0)
        self.output_shape = output_tensor.shape()
        self.output_type = output_tensor.dataType()

    def get_input_shape(self):
        """Return the cached shape of input tensor 0."""
        return self.input_shape

    def resize_input(self, shape):
        """Resize input tensor 0 to ``shape`` and re-allocate tensors.

        Does nothing when ``shape`` already matches the cached input shape.

        Args:
            shape: Sequence of dimensions (list, tuple or 1-D numpy array).
        """
        # Normalise to plain Python ints on both sides. The cached shape
        # arrives as a list, so comparing it directly with a tuple always
        # reports "changed", and comparing with a numpy array raises.
        new_shape = [int(dim) for dim in shape]
        if [int(dim) for dim in self.input_shape] != new_shape:
            self.interpreter.resizeInput(0, new_shape)
            self.allocate_tensors()

    def pred(self, x):
        """Run inference on ``x`` and return the model output.

        Args:
            x: numpy array. Its ``tobytes()`` (C order) is wrapped in a
               ``java.nio.ByteBuffer`` and used as the model input.
               NOTE(review): no dtype conversion is done here; presumably
               ``x`` must already match the model's input tensor dtype and
               size (e.g. float32) — confirm against the model.

        Returns:
            numpy array built from ``TensorBuffer.getFloatArray()``, reshaped
            to ``self.output_shape``.
        """
        # assumes one input and one output for now
        input_buffer = ByteBuffer.wrap(x.tobytes())
        output = TensorBuffer.createFixedSize(self.output_shape,
                                              self.output_type)
        self.interpreter.run(input_buffer, output.getBuffer().rewind())
        return np.reshape(np.array(output.getFloatArray()),
                          self.output_shape)