Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #245: return error for NaN in naive bayes #246

Open
wants to merge 7 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
clippy::approx_constant
)]
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]

//! # smartcore
//!
Expand Down
74 changes: 37 additions & 37 deletions src/naive_bayes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
use crate::numbers::basenum::Number;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::{cmp::Ordering, marker::PhantomData};
use std::marker::PhantomData;

/// Distribution used in the Naive Bayes classifier.
pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
Expand Down Expand Up @@ -93,41 +93,41 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
/// Returns a vector of size N with class estimates.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let y_classes = self.distribution.classes();
let predictions = x
.row_iter()
.map(|row| {
y_classes
.iter()
.enumerate()
.map(|(class_index, class)| {
(
class,
self.distribution.log_likelihood(class_index, &row)
+ self.distribution.prior(class_index).ln(),
)
})
// For some reason, the max_by method cannot use NaNs for finding the maximum value, it panics.
// NaN must be considered as minimum values,
// therefore it's like NaNs would not be considered for choosing the maximum value.
// So we need to handle this case for avoiding panicking by using `Option::unwrap`.
.max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
Some(ordering) => ordering,
None => {
if p1.is_nan() {
Ordering::Less
} else if p2.is_nan() {
Ordering::Greater
} else {
Ordering::Equal
}
}
})
.map(|(prediction, _probability)| *prediction)
.ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
})
.collect::<Result<Vec<TY>, Failed>>()?;
let y_hat = Y::from_vec_slice(&predictions);
Ok(y_hat)

if y_classes.is_empty() {
return Err(Failed::predict("Failed to predict, no classes available"));
}

let (rows, _) = x.shape();
let mut predictions = Vec::with_capacity(rows);
let mut all_probs_nan = true;

for row_index in 0..rows {
let row = x.get_row(row_index);
let mut max_log_prob = f64::NEG_INFINITY;
let mut max_class = None;

for (class_index, class) in y_classes.iter().enumerate() {
let log_likelihood = self.distribution.log_likelihood(class_index, &row);
let log_prob = log_likelihood + self.distribution.prior(class_index).ln();

if !log_prob.is_nan() && log_prob > max_log_prob {
max_log_prob = log_prob;
max_class = Some(*class);
all_probs_nan = false;
}
}

predictions.push(max_class.unwrap_or(y_classes[0]));
}

if all_probs_nan {
Err(Failed::predict(
"Failed to predict, all probabilities were NaN",
))
} else {
Ok(Y::from_vec_slice(&predictions))
}
}
}
pub mod bernoulli;
Expand Down Expand Up @@ -177,7 +177,7 @@ mod tests {
Ok(_) => panic!("Should return error in case of empty classes"),
Err(err) => assert_eq!(
err.to_string(),
"Predict failed: Failed to predict, there is no result"
"Predict failed: Failed to predict, no classes available"
),
}

Expand Down
Loading