diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index f63cc2d9..5679516a 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -903,14 +903,14 @@ impl, Y: Array1> /// /// # Errors /// - /// Returns an error if the prediction process fails. + /// Returns an error if at least one row prediction process fails. pub fn predict_proba(&self, x: &X) -> Result, Failed> { let (n_samples, _) = x.shape(); let n_classes = self.classes().len(); let mut result = DenseMatrix::::zeros(n_samples, n_classes); for i in 0..n_samples { - let probs = self.predict_proba_for_row(x, i); + let probs = self.predict_proba_for_row(x, i)?; for (j, &prob) in probs.iter().enumerate() { result.set((i, j), prob); } @@ -930,7 +930,7 @@ impl, Y: Array1> /// /// A vector of probabilities, one for each class, representing the probability /// of the input sample belonging to each class. - fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec { + fn predict_proba_for_row(&self, x: &X, row: usize) -> Result, Failed> { let mut node = 0; while let Some(current_node) = self.nodes().get(node) { @@ -938,7 +938,7 @@ impl, Y: Array1> // Leaf node reached let mut probs = vec![0.0; self.classes().len()]; probs[current_node.output] = 1.0; - return probs; + return Ok(probs); } let split_feature = current_node.split_feature; @@ -952,7 +952,7 @@ impl, Y: Array1> } // This should never happen if the tree is properly constructed - vec![0.0; self.classes().len()] + Err(Failed::predict("Nodes iteration did not reach leaf")) } }