Skip to content

Commit

Permalink
add proper error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Mec-iS committed Jan 20, 2025
1 parent fc7f2e6 commit 5711788
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/tree/decision_tree_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
///
/// # Errors
///
/// Returns an error if the prediction process fails.
/// Returns an error if at least one row prediction process fails.
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
let (n_samples, _) = x.shape();
let n_classes = self.classes().len();
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);

for i in 0..n_samples {
let probs = self.predict_proba_for_row(x, i);
let probs = self.predict_proba_for_row(x, i)?;
for (j, &prob) in probs.iter().enumerate() {
result.set((i, j), prob);
}
Expand All @@ -930,15 +930,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
///
/// A vector of probabilities, one for each class, representing the probability
/// of the input sample belonging to each class.
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
let mut node = 0;

while let Some(current_node) = self.nodes().get(node) {
if current_node.true_child.is_none() && current_node.false_child.is_none() {
// Leaf node reached
let mut probs = vec![0.0; self.classes().len()];
probs[current_node.output] = 1.0;
return probs;
return Ok(probs);
}

let split_feature = current_node.split_feature;
Expand All @@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
}

// This should never happen if the tree is properly constructed
vec![0.0; self.classes().len()]
Err(Failed::predict("Nodes iteration did not reach leaf"))
}
}

Expand Down

0 comments on commit 5711788

Please sign in to comment.