From 9c07925d8a885bfc9f45f9538287a395a688b1d8 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 20 Nov 2023 22:00:34 -0400
Subject: [PATCH 1/3] Update itertools requirement from 0.11.0 to 0.12.0
 (#271)

Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version.
- [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-itertools/itertools/compare/v0.11.0...v0.12.0)

---
updated-dependencies:
- dependency-name: itertools
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 57445059..c13003b9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -48,7 +48,7 @@ getrandom = { version = "0.2.8", optional = true }
 wasm-bindgen-test = "0.3"
 
 [dev-dependencies]
-itertools = "0.11.0"
+itertools = "0.12.0"
 serde_json = "1.0"
 bincode = "1.3.1"
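Note on the bump above: itertools is only a dev-dependency here, and 0.12 keeps the `Itertools` extension-trait API that test code typically relies on, so the bump is source-compatible for common usage. A minimal sketch of the kind of call that has to keep compiling across 0.11 -> 0.12 (hypothetical, not taken from this repository's tests):

    use itertools::Itertools;

    fn main() {
        // `tuple_windows` pairs adjacent items; it exists in both 0.11 and 0.12.
        let diffs: Vec<i32> = [1, 3, 6].iter().tuple_windows().map(|(a, b)| b - a).collect();
        assert_eq!(diffs, vec![2, 3]);
    }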
From 886b5631b7c4a8e2aac2dfa903d77e53195a564a Mon Sep 17 00:00:00 2001
From: Frédéric Meyer
Date: Wed, 10 Jan 2024 19:59:10 +0100
Subject: [PATCH 2/3] In Naive Bayes, avoid using `Option::unwrap` and so
 avoid panicking from NaN values (#274)

---
 .../hyper_tuning/grid_search.rs |  4 +-
 src/naive_bayes/mod.rs          | 94 +++++++++++++++++--
 2 files changed, 86 insertions(+), 12 deletions(-)

diff --git a/src/model_selection/hyper_tuning/grid_search.rs b/src/model_selection/hyper_tuning/grid_search.rs
index 3c914e48..74242c60 100644
--- a/src/model_selection/hyper_tuning/grid_search.rs
+++ b/src/model_selection/hyper_tuning/grid_search.rs
@@ -3,9 +3,9 @@
 use crate::{
     api::{Predictor, SupervisedEstimator},
     error::{Failed, FailedError},
-    linalg::basic::arrays::{Array2, Array1},
-    numbers::realnum::RealNumber,
+    linalg::basic::arrays::{Array1, Array2},
     numbers::basenum::Number,
+    numbers::realnum::RealNumber,
 };
 
 use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs
index e7ab7f6d..11614d14 100644
--- a/src/naive_bayes/mod.rs
+++ b/src/naive_bayes/mod.rs
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
 use crate::numbers::basenum::Number;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
-use std::marker::PhantomData;
+use std::{cmp::Ordering, marker::PhantomData};
 
 /// Distribution used in the Naive Bayes classifier.
 pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         let y_classes = self.distribution.classes();
-        let (rows, _) = x.shape();
-        let predictions = (0..rows)
-            .map(|row_index| {
-                let row = x.get_row(row_index);
-                let (prediction, _probability) = y_classes
+        let predictions = x
+            .row_iter()
+            .map(|row| {
+                y_classes
                     .iter()
                     .enumerate()
                     .map(|(class_index, class)| {
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX
                                 + self.distribution.prior(class_index).ln(),
                         )
                     })
-                    .max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
-                    .unwrap();
-                *prediction
+                    .max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
+                        Some(ordering) => ordering,
+                        None => {
+                            if p1.is_nan() {
+                                Ordering::Less
+                            } else if p2.is_nan() {
+                                Ordering::Greater
+                            } else {
+                                Ordering::Equal
+                            }
+                        }
+                    })
+                    .map(|(prediction, _probability)| *prediction)
+                    .ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
             })
-            .collect::<Vec<TY>>();
+            .collect::<Result<Vec<TY>, Failed>>()?;
         let y_hat = Y::from_vec_slice(&predictions);
         Ok(y_hat)
     }
@@ -119,3 +133,63 @@ pub mod bernoulli;
 pub mod categorical;
 pub mod gaussian;
 pub mod multinomial;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::basic::arrays::Array;
+    use crate::linalg::basic::matrix::DenseMatrix;
+    use num_traits::float::Float;
+
+    type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
+
+    #[derive(Debug, PartialEq, Clone)]
+    struct TestDistribution<'d>(&'d Vec<i32>);
+
+    impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
+        fn prior(&self, _class_index: usize) -> f64 {
+            1.
+        }
+
+        fn log_likelihood<'a>(
+            &'a self,
+            class_index: usize,
+            _j: &'a Box<dyn ArrayView1<i32> + 'a>,
+        ) -> f64 {
+            match self.0.get(class_index) {
+                &v @ 2 | &v @ 10 | &v @ 20 => v as f64,
+                _ => f64::nan(),
+            }
+        }
+
+        fn classes(&self) -> &Vec<i32> {
+            &self.0
+        }
+    }
+
+    #[test]
+    fn test_predict() {
+        let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
+
+        let val = vec![];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(_) => panic!("Should return error in case of empty classes"),
+            Err(err) => assert_eq!(
+                err.to_string(),
+                "Predict failed: Failed to predict, there is no result"
+            ),
+        }
+
+        let val = vec![1, 2, 3];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(r) => assert_eq!(r, vec![2, 2, 2]),
+            Err(_) => panic!("Should success in normal case with NaNs"),
+        }
+
+        let val = vec![20, 2, 10];
+        match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
+            Ok(r) => assert_eq!(r, vec![20, 20, 20]),
+            Err(_) => panic!("Should success in normal case without NaNs"),
+        }
+    }
+}
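Note on the comparator above: `f64` is only `PartialOrd`, so `partial_cmp` returns `None` whenever either operand is NaN, and the old `.unwrap()` turned that into a panic inside `predict`. The `match` replaces it with a total order that ranks NaN below every real number, so a NaN log-likelihood can never win `max_by`. A standalone sketch of the same idea (the free function and names are ours, not part of the patch):

    use std::cmp::Ordering;

    /// Total order on f64 that sorts NaN below all other values,
    /// mirroring the fallback used in `predict` above.
    fn nan_low_cmp(p1: f64, p2: f64) -> Ordering {
        match p1.partial_cmp(&p2) {
            Some(ordering) => ordering,
            None if p1.is_nan() => Ordering::Less,
            None if p2.is_nan() => Ordering::Greater,
            None => Ordering::Equal,
        }
    }

    fn main() {
        let scores = [f64::NAN, -1.5, 2.0, f64::NAN];
        // NaN entries simply lose every comparison instead of panicking.
        let best = scores.iter().copied().max_by(|a, b| nan_low_cmp(*a, *b));
        assert_eq!(best, Some(2.0));
    }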
From 4eadd16ce4af6d6e2364d6dac2fcb462ee18b47b Mon Sep 17 00:00:00 2001
From: Tushushu
Date: Sun, 25 Feb 2024 12:37:30 +0800
Subject: [PATCH 3/3] Implement the feature importance for Decision Tree
 Classifier (#275)

* store impurity in the node
* add number of features
* add a TODO
* draft feature importance
* feat
* n_samples of node
* compute_feature_importances
* unit tests
* always calculate impurity
* fix bug
* fix linter
---
 src/tree/decision_tree_classifier.rs | 108 +++++++++++++++++++++------
 1 file changed, 85 insertions(+), 23 deletions(-)

diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 4f36e5b9..3c9deaf7 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
     num_classes: usize,
     classes: Vec<TY>,
     depth: u16,
+    num_features: usize,
     _phantom_tx: PhantomData<TX>,
     _phantom_x: PhantomData<X>,
     _phantom_y: PhantomData<Y>,
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
 #[derive(Debug, Clone)]
 struct Node {
     output: usize,
+    n_node_samples: usize,
     split_feature: usize,
     split_value: Option<f64>,
     split_score: Option<f64>,
     true_child: Option<usize>,
     false_child: Option<usize>,
+    impurity: Option<f64>,
 }
 
 impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
 }
 
 impl Node {
-    fn new(output: usize) -> Self {
+    fn new(output: usize, n_node_samples: usize) -> Self {
         Node {
             output,
+            n_node_samples,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None,
+            impurity: Option::None,
         }
     }
 }
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: 0usize,
             classes: vec![],
             depth: 0u16,
+            num_features: 0usize,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             count[yi[i]] += samples[i];
         }
 
-        let root = Node::new(which_max(&count));
+        let root = Node::new(which_max(&count), y_ncols);
         change_nodes.push(root);
 
         let mut order: Vec<Vec<usize>> = Vec::new();
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: k,
             classes,
             depth: 0u16,
+            num_features: num_attributes,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }
 
-        if is_pure {
-            return false;
-        }
-
         let n = visitor.samples.iter().sum();
-
-        if n <= self.parameters().min_samples_split {
-            return false;
-        }
-
         let mut count = vec![0; self.num_classes];
         let mut false_count = vec![0; self.num_classes];
         for i in 0..n_rows {
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }
 
-        let parent_impurity = impurity(&self.parameters().criterion, &count, n);
+        self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
+
+        if is_pure {
+            return false;
+        }
+
+        if n <= self.parameters().min_samples_split {
+            return false;
+        }
 
         let mut variables = (0..n_attr).collect::<Vec<usize>>();
 
         for variable in variables.iter().take(mtry) {
-            self.find_best_split(
-                visitor,
-                n,
-                &count,
-                &mut false_count,
-                parent_impurity,
-                *variable,
-            );
+            self.find_best_split(visitor, n, &count, &mut false_count, *variable);
         }
 
         self.nodes()[visitor.node].split_score.is_some()
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         n: usize,
         count: &[usize],
         false_count: &mut [usize],
-        parent_impurity: f64,
         j: usize,
     ) {
         let mut true_count = vec![0; self.num_classes];
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             let true_label = which_max(&true_count);
             let false_label = which_max(false_count);
+            let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
             let gain = parent_impurity
                 - tc as f64 / n as f64 * impurity(&self.parameters().criterion, &true_count, tc)
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         let true_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.true_child_output));
+        self.nodes.push(Node::new(visitor.true_child_output, tc));
         let false_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.false_child_output));
+        self.nodes.push(Node::new(visitor.false_child_output, fc));
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         true
     }
+
+    /// Compute feature importances for the fitted tree.
+    pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
+        let mut importances = vec![0f64; self.num_features];
+
+        for node in self.nodes().iter() {
+            if node.true_child.is_none() && node.false_child.is_none() {
+                continue;
+            }
+            let left = &self.nodes()[node.true_child.unwrap()];
+            let right = &self.nodes()[node.false_child.unwrap()];
+
+            importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
+                - left.n_node_samples as f64 * left.impurity.unwrap()
+                - right.n_node_samples as f64 * right.impurity.unwrap();
+        }
+        for item in importances.iter_mut() {
+            *item /= self.nodes()[0].n_node_samples as f64;
+        }
+        if normalize {
+            let sum = importances.iter().sum::<f64>();
+            for importance in importances.iter_mut() {
+                *importance /= sum;
+            }
+        }
+        importances
+    }
 }
 
 #[cfg(test)]
 mod tests {
@@ -1016,6 +1042,42 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_compute_feature_importances() {
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 1.],
+            &[1., 1., 0., 0.],
+            &[1., 1., 0., 1.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 1.],
+            &[1., 0., 0., 0.],
+            &[1., 0., 0., 1.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 1.],
+            &[0., 1., 0., 0.],
+            &[0., 1., 0., 1.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 1.],
+            &[0., 0., 0., 0.],
+            &[0., 0., 0., 1.],
+        ]);
+        let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+        assert_eq!(
+            tree.compute_feature_importances(false),
+            vec![0., 0., 0.21333333333333332, 0.26666666666666666]
+        );
+        assert_eq!(
+            tree.compute_feature_importances(true),
+            vec![0., 0., 0.4444444444444444, 0.5555555555555556]
+        );
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
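Note on `compute_feature_importances` above: each split node contributes its weighted impurity decrease,

    n_node * impurity(node) - n_left * impurity(left) - n_right * impurity(right)

accumulated per split feature and divided by the root's sample count; with `normalize = true` the vector is rescaled to sum to 1. This is essentially the mean-decrease-impurity ("Gini") importance also used by scikit-learn. A self-contained sketch of the same arithmetic over a flattened tree (hypothetical node layout and helper name, not smartcore's types):

    /// Weighted-impurity-decrease importances over a flattened tree.
    /// children[i] = Some((left, right)) for split nodes, None for leaves;
    /// node 0 is the root.
    fn feature_importances(
        n_features: usize,
        split_feature: &[usize],
        n_samples: &[usize],
        impurity: &[f64],
        children: &[Option<(usize, usize)>],
        normalize: bool,
    ) -> Vec<f64> {
        let mut imp = vec![0.0; n_features];
        for i in 0..children.len() {
            if let Some((l, r)) = children[i] {
                imp[split_feature[i]] += n_samples[i] as f64 * impurity[i]
                    - n_samples[l] as f64 * impurity[l]
                    - n_samples[r] as f64 * impurity[r];
            }
        }
        for v in imp.iter_mut() {
            *v /= n_samples[0] as f64;
        }
        if normalize {
            let sum: f64 = imp.iter().sum();
            if sum > 0.0 {
                for v in imp.iter_mut() {
                    *v /= sum;
                }
            }
        }
        imp
    }

    fn main() {
        // One split on feature 2: root (20 samples, impurity 0.5)
        // into two children of 10 samples with impurity 0.25 each.
        let imp = feature_importances(
            4,
            &[2, 0, 0],
            &[20, 10, 10],
            &[0.5, 0.25, 0.25],
            &[Some((1, 2)), None, None],
            false,
        );
        // (20*0.5 - 10*0.25 - 10*0.25) / 20 = 0.25 for feature 2.
        assert_eq!(imp, vec![0.0, 0.0, 0.25, 0.0]);
    }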