Skip to content

Commit

Permalink
Classifier: infer() returns empty array if unlearned
Browse files Browse the repository at this point in the history
instead of assertion, that makes processing pipelines easier.
  • Loading branch information
breznak committed Oct 8, 2019
1 parent 79d88da commit bb3808b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
10 changes: 5 additions & 5 deletions bindings/py/tests/algorithms/sdr_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

""" Unit tests for Classifier & Predictor classes. """

import math
import numpy
import pickle
import random
Expand Down Expand Up @@ -131,14 +132,13 @@ def testComputeInferOrLearnOnly(self):
inp.randomize( .3 )

# learn only
with self.assertRaises(RuntimeError):
c.infer(pattern=inp) # crash with not enough training data.
prediction = c.infer(pattern=inp)[1]
self.assertTrue(prediction == []) # not enough training data -> []
c.learn(recordNum=0, pattern=inp, classification=4)
with self.assertRaises(RuntimeError):
c.infer(pattern=inp) # crash with not enough training data.
self.assertTrue(c.infer(pattern=inp)[1] == []) # not enough training data.
c.learn(recordNum=2, pattern=inp, classification=4)
c.learn(recordNum=3, pattern=inp, classification=4)
c.infer(pattern=inp) # Don't crash with not enough training data.
self.assertTrue(c.infer(pattern=inp)[1] != []) # Don't crash with enough training data.

# infer only
retval1 = c.infer(pattern=inp)
Expand Down
22 changes: 10 additions & 12 deletions py/htm/examples/hotgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,16 @@ def main(parameters=default_parameters, argv=None, verbose=True):
tm_info.addData( tm.getActiveCells().flatten() )

# Predict what will happen, and then train the predictor based on what just happened.
if count > 5: #skip the n(=to the furthest predictions step) step, as predictor must learn something first
pdf = predictor.infer( tm.getActiveCells() )
for n in (1, 5):
if pdf[n]:
predictions[n].append( np.argmax( pdf[n] ) * predictor_resolution )
else:
predictions[n].append(float('nan'))

anomalyLikelihood = anomaly_history.anomalyProbability( consumption, tm.anomaly )
anomaly.append( tm.anomaly )
anomalyProb.append( anomalyLikelihood )
pdf = predictor.infer( tm.getActiveCells() )
for n in (1, 5):
if pdf[n]:
predictions[n].append( np.argmax( pdf[n] ) * predictor_resolution )
else:
predictions[n].append(float('nan'))

anomalyLikelihood = anomaly_history.anomalyProbability( consumption, tm.anomaly )
anomaly.append( tm.anomaly )
anomalyProb.append( anomalyLikelihood )

predictor.learn(count, tm.getActiveCells(), int(consumption / predictor_resolution))

Expand All @@ -192,7 +191,6 @@ def main(parameters=default_parameters, argv=None, verbose=True):
# Calculate the predictive accuracy, Root-Mean-Squared
accuracy = {1: 0, 5: 0}
accuracy_samples = {1: 0, 5: 0}
inputs = inputs[6:] #crop the first max prediction-steps inputs (as those don't have inferences)

for idx, inp in enumerate(inputs):
for n in predictions: # For each [N]umber of time steps ahead which was predicted.
Expand Down
6 changes: 4 additions & 2 deletions src/htm/algorithms/SDRClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ void Classifier::initialize(const Real alpha)
PDF Classifier::infer(const SDR & pattern) const {
// Check input dimensions, or if this is the first time the Classifier is used and dimensions
// are unset, return zeroes.
NTA_CHECK( dimensions_ != 0 )
<< "Classifier: must call `learn` before `infer`.";
if( dimensions_ == 0 ) {
NTA_WARN << "Classifier: must call `learn` before `infer`.";
return PDF(numCategories_, std::nan("")); //empty array []
}
NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!";

// Accumulate feed forward input.
Expand Down
1 change: 1 addition & 0 deletions src/htm/algorithms/SDRClassifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Classifier : public Serializable
* @param pattern: The SDR containing the active input bits.
* @returns: The Probablility Distribution Function (PDF) of the categories.
* This is indexed by the category label.
* Or empty array ([]) if Classifier hasn't called learn() before.
*/
PDF infer(const SDR & pattern) const;

Expand Down

0 comments on commit bb3808b

Please sign in to comment.