Skip to content

Commit

Permalink
fixed kNN for kNNBasic on issue #131
Browse files Browse the repository at this point in the history
  • Loading branch information
ODemidenko committed Jan 30, 2018
1 parent fa85c0d commit 56311f0
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 103 deletions.
84 changes: 0 additions & 84 deletions surprise/prediction_algorithms/algo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class :class:`AlgoBase` from which every single prediction algorithm has to

from six import get_unbound_function as guf

from .. import similarities as sims
from .predictions import PredictionImpossible
from .predictions import Prediction
from .optimize_baselines import baseline_als
Expand All @@ -31,9 +30,6 @@ class AlgoBase(object):
def __init__(self, **kwargs):

self.bsl_options = kwargs.get('bsl_options', {})
self.sim_options = kwargs.get('sim_options', {})
if 'user_based' not in self.sim_options:
self.sim_options['user_based'] = True
self.skip_train = False

if (guf(self.__class__.fit) is guf(AlgoBase.fit) and
Expand Down Expand Up @@ -247,83 +243,3 @@ def compute_baselines(self):
raise ValueError('Invalid method ' + method_name +
' for baseline computation.' +
' Available methods are als and sgd.')

def compute_similarities(self):
    """Build the similarity matrix.

    The way the similarity matrix is computed depends on the
    ``sim_options`` parameter passed at the creation of the algorithm (see
    :ref:`similarity_measures_configuration`).

    This method is only relevant for algorithms using a similarity measure,
    such as the :ref:`k-NN algorithms <pred_package_knn_inpired>`.

    Returns:
        The similarity matrix.

    Raises:
        NameError: If the ``name`` field of ``sim_options`` is not one of
            the supported similarity measures.
    """

    construction_func = {'cosine': sims.cosine,
                         'msd': sims.msd,
                         'pearson': sims.pearson,
                         'pearson_baseline': sims.pearson_baseline}

    # Validate the requested measure before doing any work. Checking
    # membership explicitly (rather than catching KeyError around the whole
    # computation) ensures that a KeyError raised *inside* a similarity
    # routine is not misreported as an invalid measure name.
    name = self.sim_options.get('name', 'msd').lower()
    if name not in construction_func:
        raise NameError('Wrong sim name ' + name + '. Allowed values ' +
                        'are ' + ', '.join(construction_func.keys()) + '.')

    # Similarities are computed between users or between items depending on
    # the ``user_based`` option; the matching count and ratings mapping of
    # the trainset are forwarded to the similarity routine.
    if self.sim_options['user_based']:
        n_x, yr = self.trainset.n_users, self.trainset.ir
    else:
        n_x, yr = self.trainset.n_items, self.trainset.ur

    min_support = self.sim_options.get('min_support', 1)

    args = [n_x, yr, min_support]

    if name == 'pearson_baseline':
        # This measure additionally needs the global mean, the baselines
        # (ordered to match the user/item orientation) and a shrinkage
        # parameter.
        shrinkage = self.sim_options.get('shrinkage', 100)
        bu, bi = self.compute_baselines()
        if self.sim_options['user_based']:
            bx, by = bu, bi
        else:
            bx, by = bi, bu

        args += [self.trainset.global_mean, bx, by, shrinkage]

    print('Computing the {0} similarity matrix...'.format(name))
    sim = construction_func[name](*args)
    print('Done computing similarity matrix.')
    return sim

def get_neighbors(self, iid, k):
    """Return the ``k`` nearest neighbors of ``iid``, which is the inner id
    of a user or an item, depending on the ``user_based`` field of
    ``sim_options`` (see :ref:`similarity_measures_configuration`).

    As the similarities are computed on the basis of a similarity measure,
    this method is only relevant for algorithms using a similarity measure,
    such as the :ref:`k-NN algorithms <pred_package_knn_inpired>`.

    For a usage example, see the :ref:`FAQ <get_k_nearest_neighbors>`.

    Args:
        iid(int): The (inner) id of the user (or item) for which we want
            the nearest neighbors. See :ref:`this note<raw_inner_note>`.
        k(int): The number of neighbors to retrieve.

    Returns:
        The list of the ``k`` (inner) ids of the closest users (or items)
        to ``iid``, ordered from most to least similar. Candidates with
        equal similarity keep the order in which the trainset yields them.
        Fewer than ``k`` ids are returned if there are not enough
        candidates, and the list is empty when ``k`` is not positive.
    """
    # Imported locally so this method does not depend on the module's
    # import block.
    import heapq

    if self.sim_options['user_based']:
        all_instances = self.trainset.all_users
    else:
        all_instances = self.trainset.all_items

    # Every instance except ``iid`` itself is a candidate neighbor.
    candidates = (x for x in all_instances() if x != iid)

    # nlargest is equivalent to sorted(..., reverse=True)[:k] (including
    # tie order) but only keeps ``k`` candidates around instead of sorting
    # all of them.
    return heapq.nlargest(k, candidates, key=lambda x: self.sim[iid, x])
Loading

0 comments on commit 56311f0

Please sign in to comment.