-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
28 lines (21 loc) · 1004 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from collections import defaultdict
def get_top_n(predictions, n=10):
    """Return the top-N recommendations for each user from a set of predictions.

    Args:
        predictions (iterable of 5-tuples): Each element is unpacked as
            ``(user id, item id, true rating, estimated rating, details)``;
            only the user id, item id and estimated rating are used.
            Presumably these are the ``Prediction`` objects returned by the
            ``test`` method of an algorithm -- confirm against the caller.
        n (int or None): The maximum number of recommendations to output for
            each user. Default is 10. ``None`` keeps every prediction
            (sorted). Must not be negative.

    Returns:
        A ``defaultdict(list)`` where keys are user (raw) ids and values are
        lists of tuples ``[(raw item id, rating estimation), ...]`` sorted by
        decreasing estimation and holding at most ``n`` entries (fewer when
        the user has fewer than ``n`` predictions). Because the result is a
        ``defaultdict``, indexing it with an unknown user id inserts and
        returns an empty list; use ``in`` / ``.get()`` to probe safely.

    Raises:
        ValueError: If ``n`` is negative.
    """
    # A negative n would turn the slice below into "everything except the
    # last |n| items", which is never a meaningful top-N; reject it up front.
    if n is not None and n < 0:
        raise ValueError(f"n must be non-negative, got {n!r}")

    # First group the (item, estimation) pairs by user.
    top_n = defaultdict(list)
    for user_id, item_id, _, est, _ in predictions:
        top_n[user_id].append((item_id, est))

    # Then sort each user's predictions by estimated rating and keep the n
    # highest ones. list.sort is stable, so items with equal estimations keep
    # their original relative order. Rebinding an existing key does not change
    # the dict's size, so doing it while iterating over items() is safe.
    for user_id, user_ratings in top_n.items():
        user_ratings.sort(key=lambda pair: pair[1], reverse=True)
        top_n[user_id] = user_ratings[:n]

    return top_n