
Commit

Merge branch 'development' into fix-clippy-176
morenol authored Feb 25, 2024
2 parents 3e39859 + 4eadd16 · commit a071d75
Showing 4 changed files with 172 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -48,7 +48,7 @@ getrandom = { version = "0.2.8", optional = true }
 wasm-bindgen-test = "0.3"
 
 [dev-dependencies]
-itertools = "0.11.0"
+itertools = "0.12.0"
 serde_json = "1.0"
 bincode = "1.3.1"
4 changes: 2 additions & 2 deletions src/model_selection/hyper_tuning/grid_search.rs
@@ -3,9 +3,9 @@
 use crate::{
     api::{Predictor, SupervisedEstimator},
     error::{Failed, FailedError},
-    linalg::basic::arrays::{Array2, Array1},
-    numbers::realnum::RealNumber,
+    linalg::basic::arrays::{Array1, Array2},
     numbers::basenum::Number,
+    numbers::realnum::RealNumber,
 };
 
 use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
94 changes: 84 additions & 10 deletions src/naive_bayes/mod.rs
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
 use crate::numbers::basenum::Number;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
-use std::marker::PhantomData;
+use std::{cmp::Ordering, marker::PhantomData};
 
 /// Distribution used in the Naive Bayes classifier.
 pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
     /// Returns a vector of size N with class estimates.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         let y_classes = self.distribution.classes();
-        let (rows, _) = x.shape();
-        let predictions = (0..rows)
-            .map(|row_index| {
-                let row = x.get_row(row_index);
-                let (prediction, _probability) = y_classes
+        let predictions = x
+            .row_iter()
+            .map(|row| {
+                y_classes
                     .iter()
                     .enumerate()
                     .map(|(class_index, class)| {
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
                             + self.distribution.prior(class_index).ln(),
                         )
                     })
-                    .max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
-                    .unwrap();
-                *prediction
+                    // `partial_cmp` returns `None` when either log-likelihood is NaN, and the
+                    // old comparator unwrapped that `None`, panicking on the first NaN.
+                    // Rank NaN as the minimum value instead, so NaN likelihoods are never
+                    // chosen as the maximum and the comparator cannot panic.
+                    .max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
+                        Some(ordering) => ordering,
+                        None => {
+                            if p1.is_nan() {
+                                Ordering::Less
+                            } else if p2.is_nan() {
+                                Ordering::Greater
+                            } else {
+                                Ordering::Equal
+                            }
+                        }
+                    })
+                    .map(|(prediction, _probability)| *prediction)
+                    .ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
             })
-            .collect::<Vec<TY>>();
+            .collect::<Result<Vec<TY>, Failed>>()?;
         let y_hat = Y::from_vec_slice(&predictions);
         Ok(y_hat)
     }
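The reworked `predict` body leans on Rust's built-in `FromIterator` impl for `Result`: collecting an iterator of `Result<T, E>` into `Result<Vec<T>, E>` short-circuits on the first `Err`, so a single row with no usable prediction now aborts the whole call through the trailing `?`. A minimal standalone sketch of that pattern (the parsing example is illustrative, not smartcore code):

fn main() {
    // Every item parses, so the collected result is Ok with all values in order.
    let all_ok: Result<Vec<i32>, std::num::ParseIntError> =
        ["1", "2", "3"].iter().map(|s| s.parse::<i32>()).collect();
    assert_eq!(all_ok, Ok(vec![1, 2, 3]));

    // The first failing item short-circuits the whole collection into Err,
    // which is exactly what the `?` after `.collect::<Result<Vec<TY>, Failed>>()` exploits.
    let first_err: Result<Vec<i32>, std::num::ParseIntError> =
        ["1", "oops", "3"].iter().map(|s| s.parse::<i32>()).collect();
    assert!(first_err.is_err());
}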
@@ -119,3 +133,63 @@ pub mod bernoulli;
 pub mod categorical;
 pub mod gaussian;
 pub mod multinomial;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::basic::arrays::Array;
+    use crate::linalg::basic::matrix::DenseMatrix;
+    use num_traits::float::Float;
+
+    type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
+
+    #[derive(Debug, PartialEq, Clone)]
+    struct TestDistribution<'d>(&'d Vec<i32>);
+
+    impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
+        fn prior(&self, _class_index: usize) -> f64 {
+            1.
+        }
+
+        fn log_likelihood<'a>(
+            &'a self,
+            class_index: usize,
+            _j: &'a Box<dyn ArrayView1<i32> + 'a>,
+        ) -> f64 {
+            match self.0.get(class_index) {
+                &v @ 2 | &v @ 10 | &v @ 20 => v as f64,
+                _ => f64::nan(),
+            }
+        }
+
+        fn classes(&self) -> &Vec<i32> {
+            &self.0
+        }
+    }
+
+    #[test]
+    fn test_predict() {
+        let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
+
+        let val = vec![];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(_) => panic!("Should return error in case of empty classes"),
+            Err(err) => assert_eq!(
+                err.to_string(),
+                "Predict failed: Failed to predict, there is no result"
+            ),
+        }
+
+        let val = vec![1, 2, 3];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(r) => assert_eq!(r, vec![2, 2, 2]),
+            Err(_) => panic!("Should succeed in normal case with NaNs"),
+        }
+
+        let val = vec![20, 2, 10];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(r) => assert_eq!(r, vec![20, 20, 20]),
+            Err(_) => panic!("Should succeed in normal case without NaNs"),
+        }
+    }
+}
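Why the dance around `partial_cmp`: `f64` is only `PartialOrd`, and `partial_cmp` returns `None` whenever one operand is NaN, which the old comparator unwrapped. The committed fix ranks NaN below every real number, so a NaN log-likelihood can never win the argmax and nothing panics. A self-contained sketch of the same comparator (the `argmax_nan_safe` helper is a hypothetical name, for illustration only):

use std::cmp::Ordering;

/// Index of the largest value, ranking NaN below every real number so the
/// comparator never needs to unwrap a failed `partial_cmp`.
fn argmax_nan_safe(scores: &[f64]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| match a.partial_cmp(b) {
            Some(ordering) => ordering,
            None if a.is_nan() => Ordering::Less,
            None if b.is_nan() => Ordering::Greater,
            None => Ordering::Equal,
        })
        .map(|(index, _)| index)
}

fn main() {
    // The NaN score loses every comparison; index 2 (value 0.5) wins.
    assert_eq!(argmax_nan_safe(&[f64::NAN, -1.0, 0.5]), Some(2));
    // An empty slice yields None instead of panicking, mirroring the
    // `ok_or_else` that turns "no classes" into a `Failed` error above.
    assert_eq!(argmax_nan_safe(&[]), None);
}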
108 changes: 85 additions & 23 deletions src/tree/decision_tree_classifier.rs
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
     num_classes: usize,
     classes: Vec<TY>,
     depth: u16,
+    num_features: usize,
     _phantom_tx: PhantomData<TX>,
     _phantom_x: PhantomData<X>,
     _phantom_y: PhantomData<Y>,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
 #[derive(Debug, Clone)]
 struct Node {
     output: usize,
+    n_node_samples: usize,
     split_feature: usize,
     split_value: Option<f64>,
     split_score: Option<f64>,
     true_child: Option<usize>,
     false_child: Option<usize>,
+    impurity: Option<f64>,
 }
 
 impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
 }
 
 impl Node {
-    fn new(output: usize) -> Self {
+    fn new(output: usize, n_node_samples: usize) -> Self {
         Node {
             output,
+            n_node_samples,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None,
+            impurity: Option::None,
         }
     }
 }
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: 0usize,
             classes: vec![],
             depth: 0u16,
+            num_features: 0usize,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             count[yi[i]] += samples[i];
         }
 
-        let root = Node::new(which_max(&count));
+        let root = Node::new(which_max(&count), y_ncols);
         change_nodes.push(root);
         let mut order: Vec<Vec<usize>> = Vec::new();
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: k,
             classes,
             depth: 0u16,
+            num_features: num_attributes,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }
 
-        if is_pure {
-            return false;
-        }
-
         let n = visitor.samples.iter().sum();
-
-        if n <= self.parameters().min_samples_split {
-            return false;
-        }
-
         let mut count = vec![0; self.num_classes];
         let mut false_count = vec![0; self.num_classes];
         for i in 0..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }
 
-        let parent_impurity = impurity(&self.parameters().criterion, &count, n);
+        self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
+
+        if is_pure {
+            return false;
+        }
+
+        if n <= self.parameters().min_samples_split {
+            return false;
+        }
 
         let mut variables = (0..n_attr).collect::<Vec<_>>();

@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }
 
         for variable in variables.iter().take(mtry) {
-            self.find_best_split(
-                visitor,
-                n,
-                &count,
-                &mut false_count,
-                parent_impurity,
-                *variable,
-            );
+            self.find_best_split(visitor, n, &count, &mut false_count, *variable);
         }
 
         self.nodes()[visitor.node].split_score.is_some()
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         n: usize,
         count: &[usize],
         false_count: &mut [usize],
-        parent_impurity: f64,
         j: usize,
     ) {
         let mut true_count = vec![0; self.num_classes];
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         let true_label = which_max(&true_count);
         let false_label = which_max(false_count);
+        let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
         let gain = parent_impurity
             - tc as f64 / n as f64
                 * impurity(&self.parameters().criterion, &true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         let true_child_idx = self.nodes().len();
 
-        self.nodes.push(Node::new(visitor.true_child_output));
+        self.nodes.push(Node::new(visitor.true_child_output, tc));
         let false_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.false_child_output));
+        self.nodes.push(Node::new(visitor.false_child_output, fc));
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         true
     }
+
+    /// Compute feature importances for the fitted tree.
+    pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
+        let mut importances = vec![0f64; self.num_features];
+
+        for node in self.nodes().iter() {
+            if node.true_child.is_none() && node.false_child.is_none() {
+                continue;
+            }
+            let left = &self.nodes()[node.true_child.unwrap()];
+            let right = &self.nodes()[node.false_child.unwrap()];
+
+            importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
+                - left.n_node_samples as f64 * left.impurity.unwrap()
+                - right.n_node_samples as f64 * right.impurity.unwrap();
+        }
+        for item in importances.iter_mut() {
+            *item /= self.nodes()[0].n_node_samples as f64;
+        }
+        if normalize {
+            let sum = importances.iter().sum::<f64>();
+            for importance in importances.iter_mut() {
+                *importance /= sum;
+            }
+        }
+        importances
+    }
 }
 
 #[cfg(test)]
@@ -1016,6 +1042,42 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_compute_feature_importances() {
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 1.],
+            &[1., 1., 0., 0.],
+            &[1., 1., 0., 1.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 1.],
+            &[1., 0., 0., 0.],
+            &[1., 0., 0., 1.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 1.],
+            &[0., 1., 0., 0.],
+            &[0., 1., 0., 1.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 1.],
+            &[0., 0., 0., 0.],
+            &[0., 0., 0., 1.],
+        ]);
+        let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+        assert_eq!(
+            tree.compute_feature_importances(false),
+            vec![0., 0., 0.21333333333333332, 0.26666666666666666]
+        );
+        assert_eq!(
+            tree.compute_feature_importances(true),
+            vec![0., 0., 0.4444444444444444, 0.5555555555555556]
+        );
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
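The importance measure implemented above is mean decrease in impurity: every internal node contributes n_node * impurity_node - n_true * impurity_true - n_false * impurity_false to its split feature, the totals are scaled by the root's sample count, and normalize divides by the grand total. A standalone sketch of the same arithmetic under assumed names (SplitNode and feature_importances are hypothetical, not smartcore types, and the zero-sum guard is an addition; the committed method divides unconditionally):

// Hypothetical stand-in for the tree's node type; not the smartcore `Node`.
struct SplitNode {
    feature: usize,                   // feature this node splits on
    n_samples: f64,                   // samples reaching this node
    impurity: f64,                    // node impurity (e.g. Gini)
    children: Option<(usize, usize)>, // (true_child, false_child) indices, None for leaves
}

fn feature_importances(nodes: &[SplitNode], num_features: usize, normalize: bool) -> Vec<f64> {
    let mut importances = vec![0.0; num_features];
    for node in nodes {
        if let Some((t, f)) = node.children {
            // Weighted impurity decrease produced by this split.
            importances[node.feature] += node.n_samples * node.impurity
                - nodes[t].n_samples * nodes[t].impurity
                - nodes[f].n_samples * nodes[f].impurity;
        }
    }
    // Scale by the root's sample count, as compute_feature_importances does.
    for v in importances.iter_mut() {
        *v /= nodes[0].n_samples;
    }
    if normalize {
        let sum: f64 = importances.iter().sum();
        if sum > 0.0 {
            for v in importances.iter_mut() {
                *v /= sum;
            }
        }
    }
    importances
}

fn main() {
    // A root splitting 20 samples on feature 2 (impurity 0.5) into two pure leaves.
    let nodes = vec![
        SplitNode { feature: 2, n_samples: 20.0, impurity: 0.5, children: Some((1, 2)) },
        SplitNode { feature: 0, n_samples: 12.0, impurity: 0.0, children: None },
        SplitNode { feature: 0, n_samples: 8.0, impurity: 0.0, children: None },
    ];
    // (20 * 0.5 - 12 * 0.0 - 8 * 0.0) / 20 = 0.5 for feature 2, zero elsewhere.
    assert_eq!(feature_importances(&nodes, 3, false), vec![0.0, 0.0, 0.5]);
}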
