diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index c157de47..c30eabc9 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1}; use crate::numbers::basenum::Number; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; +use std::{cmp::Ordering, marker::PhantomData}; /// Distribution used in the Naive Bayes classifier. pub(crate) trait NBDistribution: Clone { @@ -93,6 +93,7 @@ impl, Y: Array1, D: NBDistribution Result { let y_classes = self.distribution.classes(); + if y_classes.is_empty() { return Err(Failed::predict("Failed to predict, no classes available")); }