Skip to content

Commit

Permalink
fix ranger class mapping on predictions (#45)
Browse files Browse the repository at this point in the history
* fix ranger class mapping on predictions

* rename class order var

* upgrade poetry and fix test and docs

* rm post

* revert updates

* revert poetry update

* make docs build
  • Loading branch information
crflynn authored Oct 24, 2020
1 parent cf5ca9d commit 925bd9d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ qtconsole==4.7.5
qtpy==1.9.0
regex==2020.6.8
requests==2.24.0
scikit-learn==0.23.1
scikit-learn==0.22.0
scikit-survival==0.12.0
scipy==1.5.1
send2trash==1.5.0
Expand Down
7 changes: 5 additions & 2 deletions skranger/ensemble/ranger_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class specific values.
regularization factor input parameter.
:ivar int importance_mode\_: The importance mode integer corresponding to ranger
enum ``ImportanceMode``.
:ivar list ranger_class_order\_: The class reference ordering derived from ranger.
"""

def __init__(
Expand Down Expand Up @@ -156,7 +157,7 @@ def fit(self, X, y, sample_weight=None):
self._validate_parameters(X, y, sample_weight)

# Map classes to indices
y = y.copy()
y = np.copy(y)
self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = len(self.classes_)

Expand Down Expand Up @@ -217,6 +218,7 @@ def fit(self, X, y, sample_weight=None):
False, # use_regularization_factor
self.regularization_usedepth,
)
self.ranger_class_order_ = np.argsort(np.array(self.ranger_forest_["forest"]["class_values"]).astype(int))
return self

def predict(self, X):
Expand Down Expand Up @@ -282,7 +284,8 @@ def predict_proba(self, X):
self.use_regularization_factor_,
self.regularization_usedepth,
)
return np.atleast_2d(np.array(result["predictions"]))
predictions = np.atleast_2d(np.array(result["predictions"]))
return predictions[:, self.ranger_class_order_]

def predict_log_proba(self, X):
"""Predict log probabilities for classes from X.
Expand Down
23 changes: 23 additions & 0 deletions tests/ensemble/test_ranger_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import numpy as np
import pytest
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_is_fitted

from skranger.ensemble import RangerForestClassifier
Expand All @@ -24,6 +27,7 @@ def test_fit(self, iris_X, iris_y):
assert hasattr(rfc, "classes_")
assert hasattr(rfc, "n_classes_")
assert hasattr(rfc, "ranger_forest_")
assert hasattr(rfc, "ranger_class_order_")
assert hasattr(rfc, "n_features_")

def test_predict(self, iris_X, iris_y):
Expand Down Expand Up @@ -232,3 +236,22 @@ def test_always_split_features(self, iris_X, iris_y):
# feature 0 is in every tree split
for tree in rfc.ranger_forest_["forest"]["split_var_ids"]:
assert 0 in tree

def test_accuracy(self, iris_X, iris_y):
X_train, X_test, y_train, y_test = train_test_split(iris_X, iris_y, test_size=0.33, random_state=42)

# train and test a random forest classifier
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred_rf = rf.predict(X_test)
rf_acc = accuracy_score(y_test, y_pred_rf)

# train and test a ranger classifier
ra = RangerForestClassifier()
ra.fit(X_train, y_train)
y_pred_ra = ra.predict(X_test)
ranger_acc = accuracy_score(y_test, y_pred_ra)

# the accuracy should be good
assert rf_acc > 0.9
assert ranger_acc > 0.9

0 comments on commit 925bd9d

Please sign in to comment.