Add GBDTs feature importance #292
xnuohz commented Dec 12, 2023
- Add APIs to get feature importance
- Add test case
- Update example
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@           Coverage Diff           @@
##           master     #292   +/-   ##
=======================================
  Coverage   93.41%   93.41%
=======================================
  Files         116      116
  Lines        5949     5970   +21
=======================================
+ Hits         5557     5577   +20
- Misses        392      393    +1
torch_frame/gbdt/tuned_xgboost.py
Outdated
@@ -232,3 +232,7 @@ def _load(self, path: str) -> None:
        import xgboost

        self.model = xgboost.Booster(model_file=path)

    def _feature_importance(self) -> list:
        scores = self.model.get_score(importance_type='weight')
Maybe `weight` can be passed as an argument.
torch_frame/gbdt/tuned_lightgbm.py
Outdated
@@ -226,3 +226,7 @@ def _load(self, path: str) -> None:
        import lightgbm

        self.model = lightgbm.Booster(model_file=path)

    def _feature_importance(self) -> list:
        scores = self.model.feature_importance(importance_type='gain')
same here.
torch_frame/gbdt/gbdt.py
Outdated
@@ -135,6 +139,19 @@ def load(self, path: str) -> None:
        self._load(path)
        self._is_fitted = True

    def feature_importance(self) -> list:
Suggested change:
- def feature_importance(self) -> list:
+ def feature_importance(self, *args, **kwargs) -> list:
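A minimal sketch of how the suggested signature could forward arguments from the public method to a backend-specific hook. `ToyGBDT`, `ToyBackend`, the canned scores, and the fitted check are hypothetical stand-ins for illustration, not the torch_frame implementation; only the names `feature_importance`, `_feature_importance`, and `_is_fitted` come from the diffs in this PR.

```python
class ToyGBDT:
    """Hypothetical base class mirroring the public/private split in the diff."""

    def __init__(self) -> None:
        self._is_fitted = False

    def feature_importance(self, *args, **kwargs) -> list:
        # Assumption: refuse to report scores before the model is fitted or loaded.
        if not self._is_fitted:
            raise RuntimeError("Fit or load the model before asking for scores.")
        # Forward everything untouched so each backend can define its own options.
        return self._feature_importance(*args, **kwargs)

    def _feature_importance(self, *args, **kwargs) -> list:
        raise NotImplementedError


class ToyBackend(ToyGBDT):
    """Hypothetical backend that serves canned scores instead of a trained booster."""

    _SCORES = {"weight": [3.0, 1.0, 0.0], "gain": [0.7, 0.3, 0.0]}

    def _feature_importance(self, importance_type: str = "weight") -> list:
        if importance_type not in self._SCORES:
            raise ValueError(
                f"Expect one of {sorted(self._SCORES)}, got {importance_type}.")
        return list(self._SCORES[importance_type])


model = ToyBackend()
model._is_fitted = True  # pretend the model was fitted
print(model.feature_importance(importance_type="gain"))
```

With this shape the base class never needs to know which importance types a given backend supports; the trade-off is that the public signature no longer documents the accepted options.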
test/gbdt/test_gbdt.py
Outdated
num_features = 0
for x in stypes:
    if x == stype.numerical:
        num_features += 3 * 1
I don't quite get the code here.
Here I want to get the total number of `FakeDataset` features. 3 is the number of features and 1 is the dimension of each.
Can you do this in a more generic way, e.g. by getting the values and dimensions from `col_names_dict` or `tensor_frame`, rather than using magic numbers?
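One way the count could be derived instead of hard-coded is sketched below. It assumes a `col_names_dict`-style mapping from semantic type to column names; plain strings stand in for the `stype` enum, and `count_features` and `dim_per_stype` are hypothetical names introduced only for this illustration.

```python
from typing import Dict, List


def count_features(col_names_dict: Dict[str, List[str]],
                   dim_per_stype: Dict[str, int]) -> int:
    """Sum (number of columns) * (per-column dimension) over all semantic types."""
    total = 0
    for stype_name, col_names in col_names_dict.items():
        # A missing entry raises KeyError on purpose: no silent default dimension.
        total += len(col_names) * dim_per_stype[stype_name]
    return total


# Hypothetical stand-in for the mapping a dataset would expose.
col_names_dict = {
    "numerical": ["num_1", "num_2", "num_3"],
    "categorical": ["cat_1", "cat_2", "cat_3"],
}
dim_per_stype = {"numerical": 1, "categorical": 1}

print(count_features(col_names_dict, dim_per_stype))  # 3 * 1 + 3 * 1 = 6
```

The test then stays correct if the fake dataset changes how many columns it generates per semantic type.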
Left some comments. @weihua916 or @zechengz or @akihironitta should also take a look.
iteration (int, optional): Limit number of iterations in the feature
    importance calculation. If None, if the best iteration exists,
    it is used; otherwise, all trees are used. If <= 0, all trees
    are used (no limits).
Add doc-string on iteration
examples/tuned_gbdt.py
Outdated
    'feature': dataset.feat_cols,
    'importance': gbdt.feature_importance()
}).sort_values(by='importance', ascending=False)
print(scores)
Add some more text around the scores
Maybe we can add a parser argument that lets users specify whether they want feature importance.
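A rough sketch of what such an opt-in flag might look like, with a labeled header printed around the scores. The flag name `--feature_importance` and the helpers `build_parser` and `report_importance` are hypothetical, and plain Python sorting replaces the pandas `DataFrame` from the example so the snippet has no dependencies.

```python
import argparse
from typing import List, Optional, Sequence, Tuple


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Hypothetical flag name; feature importance is skipped unless it is passed.
    parser.add_argument("--feature_importance", action="store_true",
                        help="Print feature importance after training.")
    return parser


def report_importance(
    enabled: bool,
    feat_cols: Sequence[str],
    scores: Sequence[float],
) -> Optional[List[Tuple[str, float]]]:
    """Print features sorted by descending importance; do nothing when disabled."""
    if not enabled:
        return None
    ranked = sorted(zip(feat_cols, scores), key=lambda pair: pair[1], reverse=True)
    print("Feature importance (highest first):")
    for name, score in ranked:
        print(f"  {name}: {score:.4f}")
    return ranked


# Explicit argv keeps the sketch independent of how the script is invoked.
args = build_parser().parse_args(["--feature_importance"])
report_importance(args.feature_importance, ["age", "income"], [0.1, 0.9])
```

Using `action="store_true"` keeps the default behavior of the example unchanged for users who do not pass the flag.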
], f'Expect split or gain, got {importance_type}.'
scores = self.model.feature_importance(importance_type=importance_type,
                                       iteration=iteration)
return scores.tolist()
Will this be just a list of scores? IMO it's better to return a dictionary where the keys are column names and the values are the corresponding scores. WDYT?
The GBDT libraries' feature importance APIs return different types. For convenience, I converted them all to lists.
- lightgbm -> ndarray
- xgboost -> dict[str, float]
- catboost -> ndarray
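If a dictionary keyed by column name is preferred, both shapes listed above could be normalized by one helper, sketched here under stated assumptions. `to_importance_dict` is a hypothetical name; the sketch assumes that array-like scores follow the order of the feature columns, that a dict-style result is keyed by the same column names, and that a column missing from the dict should be reported as 0.0.

```python
from collections.abc import Mapping
from typing import Dict, Sequence


def to_importance_dict(feat_cols: Sequence[str], raw) -> Dict[str, float]:
    """Map each column name to its score for dict-like or array-like input."""
    if isinstance(raw, Mapping):
        # Assumption: a column absent from the mapping has zero importance.
        return {name: float(raw.get(name, 0.0)) for name in feat_cols}
    scores = list(raw)  # works for lists, tuples, and 1-D arrays
    if len(scores) != len(feat_cols):
        raise ValueError(
            f"Expected {len(feat_cols)} scores, got {len(scores)}.")
    return {name: float(score) for name, score in zip(feat_cols, scores)}


feat_cols = ["age", "income", "city"]
print(to_importance_dict(feat_cols, {"age": 2.0, "city": 1.0}))  # dict-style
print(to_importance_dict(feat_cols, [5.0, 0.0, 3.0]))            # array-style
```

Returning a dictionary makes the pairing of names and scores explicit, at the cost of requiring the caller to supply the column order for array-style results.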