-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrecognize_commands.h
156 lines (137 loc) · 5.95 KB
/
recognize_commands.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#include <cstdint>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/micro_model_settings.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
// Partial implementation of std::deque, providing just the functionality
// that's needed to keep a record of previous neural network results over a
// short time period, so they can be averaged together to produce a more
// accurate overall prediction. This doesn't use any dynamic memory allocation,
// so it's a better fit for microcontroller applications, but it does mean
// there is a hard limit (kMaxResults) on the number of results it can store.
//
// Storage is a fixed-size ring buffer: front_index_ is the slot holding the
// oldest entry and size_ is the number of live entries; logical positions
// wrap around at kMaxResults.
class PreviousResultsQueue {
 public:
  // error_reporter is used to report misuse (overflow, underflow and
  // out-of-range reads). The pointer is stored but never deleted by this
  // class, so it must stay valid for the queue's lifetime.
  explicit PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
      : error_reporter_(error_reporter), front_index_(0), size_(0) {}

  // Data structure that holds an inference result, and the time when it
  // was recorded.
  struct Result {
    Result() : time_(0), scores_() {}
    // Copies kCategoryCount scores; `scores` must point to at least that
    // many elements. Taking a const pointer still accepts mutable buffers.
    Result(int32_t time, const uint8_t* scores) : time_(time) {
      for (int i = 0; i < kCategoryCount; ++i) {
        scores_[i] = scores[i];
      }
    }
    int32_t time_;
    uint8_t scores_[kCategoryCount];
  };

  int size() const { return size_; }
  bool empty() const { return size_ == 0; }

  // Oldest entry. On an empty queue this returns an (in-bounds) stale slot.
  Result& front() { return results_[front_index_]; }

  // Newest entry. On an empty queue an error is reported and the front slot
  // is returned, instead of indexing before the start of the buffer.
  Result& back() {
    if (size_ <= 0) {
      error_reporter_->Report("Couldn't read back result, none present!");
      return results_[front_index_];
    }
    return results_[WrapIndex(front_index_ + (size_ - 1))];
  }

  // Appends `entry`; reports an error and drops it if the queue is full.
  void push_back(const Result& entry) {
    if (size() >= kMaxResults) {
      error_reporter_->Report(
          "Couldn't push_back latest result, too many already!");
      return;
    }
    // The slot one past the current newest entry becomes the new back.
    results_[WrapIndex(front_index_ + size_)] = entry;
    size_ += 1;
  }

  // Removes and returns the oldest entry; reports an error and returns a
  // default-constructed Result if the queue is empty.
  Result pop_front() {
    if (size() <= 0) {
      error_reporter_->Report("Couldn't pop_front result, none present!");
      return Result();
    }
    Result result = front();
    front_index_ = WrapIndex(front_index_ + 1);
    size_ -= 1;
    return result;
  }

  // Most of the functions are duplicates of std::deque's, but this is a
  // helper that makes it easy to iterate through the contents of the queue:
  // offset 0 is the oldest entry, size() - 1 the newest. An out-of-range
  // offset is reported and clamped to the newest entry (or to the front slot
  // when the queue is empty, so the index can never become negative).
  Result& from_front(int offset) {
    if ((offset < 0) || (offset >= size_)) {
      error_reporter_->Report("Attempt to read beyond the end of the queue!");
      offset = (size_ > 0) ? (size_ - 1) : 0;
    }
    return results_[WrapIndex(front_index_ + offset)];
  }

 private:
  // Maps a position in [0, 2 * kMaxResults) back into [0, kMaxResults).
  // Callers only pass front_index_ plus a non-negative offset that is below
  // kMaxResults, so a single subtraction suffices.
  static int WrapIndex(int index) {
    return (index >= kMaxResults) ? (index - kMaxResults) : index;
  }

  tflite::ErrorReporter* error_reporter_;
  static constexpr int kMaxResults = 50;
  Result results_[kMaxResults];
  int front_index_;
  int size_;
};
// This class is designed to apply a very primitive decoding model on top of the
// instantaneous results from running an audio recognition model on a single
// window of samples. It applies smoothing over time so that noisy individual
// label scores are averaged, increasing the confidence that apparent matches
// are real.
//
// To use it, create an object with the configuration you want, then feed the
// results of each model invocation into ProcessLatestResults(). The timestamp
// passed to each call should be greater than the previous one, since the class
// is designed to process a stream of data over time.
class RecognizeCommands {
 public:
  // The window duration controls the smoothing. Longer durations will give a
  // higher confidence that the results are correct, but may miss some commands.
  // The detection threshold has a similar effect, with high values increasing
  // the precision at the cost of recall. The minimum count controls how many
  // results need to be in the averaging window before it's seen as a reliable
  // average. This prevents erroneous results when the averaging window is
  // initially being populated, for example. The suppression argument disables
  // further recognitions for a set time after one has been triggered, which can
  // help reduce spurious recognitions.
  //
  // error_reporter is stored as a raw pointer and presumably must outlive this
  // object (NOTE(review): confirm against the .cc implementation).
  explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
                             int32_t average_window_duration_ms = 1000,
                             uint8_t detection_threshold = 200,
                             int32_t suppression_ms = 1500,
                             int32_t minimum_count = 3);

  // Call this with the results of running a model on sample data.
  //   latest_results:  the model's output tensor for the newest window.
  //   current_time_ms: timestamp of those results, in milliseconds.
  //   found_command, score, is_new_command: output parameters; presumably
  //     the currently recognized label, its smoothed score, and whether it is
  //     a fresh detection (TODO confirm against the implementation).
  // Returns a TfLiteStatus describing whether processing succeeded.
  TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
                                    const int32_t current_time_ms,
                                    const char** found_command, uint8_t* score,
                                    bool* is_new_command);

 private:
  // Configuration
  tflite::ErrorReporter* error_reporter_;
  int32_t average_window_duration_ms_;
  uint8_t detection_threshold_;
  int32_t suppression_ms_;
  int32_t minimum_count_;

  // Working variables. The in-class defaults only guarantee these are never
  // indeterminate; any value set by the constructor takes precedence.
  PreviousResultsQueue previous_results_;
  const char* previous_top_label_ = nullptr;
  int32_t previous_top_label_time_ = 0;
};
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_