From 71fabcb831058db45ce2ef069b59a0158d1328e4 Mon Sep 17 00:00:00 2001 From: Nicole White Date: Mon, 18 Dec 2017 17:32:21 -0800 Subject: [PATCH 1/2] Allow feature_names to be None --- lime/lime_tabular.py | 10 +++------- lime/tests/test_lime_tabular.py | 5 +++++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py index 3c9405db..47761a93 100644 --- a/lime/lime_tabular.py +++ b/lime/lime_tabular.py @@ -146,15 +146,11 @@ def __init__(self, """ self.random_state = check_random_state(random_state) self.mode = mode - self.feature_names = list(feature_names) self.categorical_names = categorical_names self.categorical_features = categorical_features - if self.categorical_names is None: - self.categorical_names = {} - if self.categorical_features is None: - self.categorical_features = [] - if self.feature_names is None: - self.feature_names = [str(i) for i in range(training_data.shape[1])] + self.categorical_names = categorical_names or {} + self.categorical_features = categorical_features or [] + self.feature_names = feature_names or [str(i) for i in range(training_data.shape[1])] self.discretizer = None if discretize_continuous: diff --git a/lime/tests/test_lime_tabular.py b/lime/tests/test_lime_tabular.py index 99c77c28..5e5c5ea5 100644 --- a/lime/tests/test_lime_tabular.py +++ b/lime/tests/test_lime_tabular.py @@ -535,6 +535,11 @@ def test_lime_tabular_explainer_not_equal_random_state(self): self.assertFalse(exp_1.as_map() != exp_2.as_map()) + def testFeatureNames(self): + explainer = LimeTabularExplainer(training_data=np.array([[0., 1.], [1., 0.]])) + + self.assertEqual(explainer.feature_names, ['0', '1']) + if __name__ == '__main__': unittest.main() From a5fb3b0e47e88adc95c2c6fa1584957c1aab5977 Mon Sep 17 00:00:00 2001 From: Nicole White Date: Mon, 18 Dec 2017 18:05:55 -0800 Subject: [PATCH 2/2] Forgot to remove earlier assignment --- lime/lime_tabular.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py index 47761a93..6e21d1e0 100644 --- a/lime/lime_tabular.py +++ b/lime/lime_tabular.py @@ -146,8 +146,6 @@ def __init__(self, """ self.random_state = check_random_state(random_state) self.mode = mode - self.categorical_names = categorical_names - self.categorical_features = categorical_features self.categorical_names = categorical_names or {} self.categorical_features = categorical_features or [] self.feature_names = feature_names or [str(i) for i in range(training_data.shape[1])]