diff --git a/pylmnn/lmnn.py b/pylmnn/lmnn.py index 6c61b92..09d3d0a 100755 --- a/pylmnn/lmnn.py +++ b/pylmnn/lmnn.py @@ -24,8 +24,14 @@ from sklearn.utils.random import check_random_state from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import check_is_fitted, check_array, check_X_y -from sklearn.externals.six import integer_types, string_types from sklearn.exceptions import ConvergenceWarning +try: + from six import integer_types, string_types +except ImportError: + try: + from sklearn.externals.six import integer_types, string_types + except ImportError: + raise ImportError("The module six must be installed or the version of scikit-learn version must be < 0.23") from .utils import _euclidean_distances_without_checks diff --git a/tests/test_lmnn.py b/tests/test_lmnn.py index 831a760..29f5fae 100644 --- a/tests/test_lmnn.py +++ b/tests/test_lmnn.py @@ -17,9 +17,17 @@ from sklearn.metrics.pairwise import euclidean_distances from sklearn.model_selection import train_test_split from sklearn.utils.extmath import row_norms -from sklearn.externals.six.moves import cStringIO as StringIO from sklearn.exceptions import ConvergenceWarning +try: + from six.moves import cStringIO as StringIO +except ImportError: + try: + from sklearn.externals.six.moves import cStringIO as StringIO + except ImportError: + raise ImportError("The module six must be installed or the version of scikit-learn version must be < 0.23") + + from pylmnn import LargeMarginNearestNeighbor from pylmnn import make_lmnn_pipeline from pylmnn.lmnn import _paired_distances_blockwise