Skip to content

Commit

Permalink
Merge pull request #132 from nicolewhite/tabular-args
Browse files Browse the repository at this point in the history
Allow feature_names to be None
  • Loading branch information
marcotcr authored Dec 22, 2017
2 parents 17a50c2 + a5fb3b0 commit 706647e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 3 additions & 9 deletions lime/lime_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,9 @@ def __init__(self,
"""
self.random_state = check_random_state(random_state)
self.mode = mode
self.feature_names = list(feature_names)
self.categorical_names = categorical_names
self.categorical_features = categorical_features
if self.categorical_names is None:
self.categorical_names = {}
if self.categorical_features is None:
self.categorical_features = []
if self.feature_names is None:
self.feature_names = [str(i) for i in range(training_data.shape[1])]
self.categorical_names = categorical_names or {}
self.categorical_features = categorical_features or []
self.feature_names = feature_names or [str(i) for i in range(training_data.shape[1])]

self.discretizer = None
if discretize_continuous:
Expand Down
5 changes: 5 additions & 0 deletions lime/tests/test_lime_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,11 @@ def test_lime_tabular_explainer_not_equal_random_state(self):

self.assertFalse(exp_1.as_map() != exp_2.as_map())

def testFeatureNames(self):
explainer = LimeTabularExplainer(training_data=np.array([[0., 1.], [1., 0.]]))

self.assertEqual(explainer.feature_names, ['0', '1'])


if __name__ == '__main__':
unittest.main()

0 comments on commit 706647e

Please sign in to comment.