Skip to content

Commit

Permalink
Merge pull request #3 from luccaportes/six_import
Browse files Browse the repository at this point in the history
change on six import
  • Loading branch information
johny-c authored May 9, 2020
2 parents 856dd8d + 0d92dca commit ed66aaf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
8 changes: 7 additions & 1 deletion pylmnn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
from sklearn.utils.random import check_random_state
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted, check_array, check_X_y
from sklearn.externals.six import integer_types, string_types
from sklearn.exceptions import ConvergenceWarning
try:
from six import integer_types, string_types
except ImportError:
try:
from sklearn.externals.six import integer_types, string_types
except ImportError:
raise ImportError("The module six must be installed or the version of scikit-learn version must be < 0.23")

from .utils import _euclidean_distances_without_checks

Expand Down
10 changes: 9 additions & 1 deletion tests/test_lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import train_test_split
from sklearn.utils.extmath import row_norms
from sklearn.externals.six.moves import cStringIO as StringIO
from sklearn.exceptions import ConvergenceWarning

try:
from six.moves import cStringIO as StringIO
except ImportError:
try:
from sklearn.externals.six.moves import cStringIO as StringIO
except ImportError:
raise ImportError("The module six must be installed or the version of scikit-learn version must be < 0.23")


from pylmnn import LargeMarginNearestNeighbor
from pylmnn import make_lmnn_pipeline
from pylmnn.lmnn import _paired_distances_blockwise
Expand Down

0 comments on commit ed66aaf

Please sign in to comment.