Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified predict to compute representations just once #105

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 100 additions & 21 deletions lightfm/_lightfm_fast.pyx.template
Original file line number Diff line number Diff line change
Expand Up @@ -1139,51 +1139,129 @@ def fit_bpr(CSRMatrix item_features,
user_alpha)


def predict_lightfm(CSRMatrix item_features,
CSRMatrix user_features,
int[::1] user_ids,
int[::1] item_ids,
double[::1] predictions,
FastLightFM lightfm,
int num_threads):
cdef precompute_unique(CSRMatrix item_features,
CSRMatrix user_features,
int[::1] unique_users,
int[::1] unique_items,
flt *user_reprs,
flt *it_reprs,
FastLightFM lightfm,
int num_threads):
"""
Generate predictions.
Precomputes the representations for all the users in unique_users and
all the items in unique_items
"""

cdef int i, no_examples
cdef flt *user_repr
cdef int i, j
cdef flt *it_repr
cdef flt *user_repr
cdef int no_features
cdef int no_users

no_examples = predictions.shape[0]

no_features = unique_items.shape[0]
no_users = unique_users.shape[0]
{nogil_block}

user_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
it_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))

for i in {range_block}(no_examples):

# users representations
for i in {range_block}(no_users):
compute_representation(user_features,
lightfm.user_features,
lightfm.user_biases,
lightfm,
user_ids[i],
unique_users[i],
lightfm.user_scale,
user_repr)
for j in {range_block}(lightfm.no_components + 1):
user_reprs[i * (lightfm.no_components + 1) + j] = user_repr[j]

# items representations
for i in {range_block}(no_features):
compute_representation(item_features,
lightfm.item_features,
lightfm.item_biases,
lightfm,
item_ids[i],
unique_items[i],
lightfm.item_scale,
it_repr)
for j in {range_block}(lightfm.no_components + 1):
it_reprs[i * (lightfm.no_components + 1) + j] = it_repr[j]


def predict_lightfm(CSRMatrix item_features,
CSRMatrix user_features,
int[::1] user_ids,
int[::1] item_ids,
double[::1] predictions,
FastLightFM lightfm,
int num_threads,
bint precompute):
"""
Generate predictions.
"""
cdef int i, j, no_examples
cdef flt *user_repr
cdef flt *it_repr
cdef flt *user_reprs
cdef flt *it_reprs
cdef int[::1] unique_users
cdef int[::1] unique_items
cdef long[::1] inverse_users
cdef long[::1] inverse_items
cdef int no_features
cdef int no_users

no_examples = predictions.shape[0]

if precompute:
unique_users, inverse_users = np.unique(user_ids, return_inverse=True)
unique_items, inverse_items = np.unique(item_ids, return_inverse=True)
no_features = unique_items.shape[0]
no_users = unique_users.shape[0]

user_reprs = <flt *>malloc(sizeof(flt) * no_users * (lightfm.no_components + 1))
it_reprs = <flt *>malloc(sizeof(flt) * no_features *(lightfm.no_components + 1))
precompute_unique(item_features,
user_features,
unique_users,
unique_items,
user_reprs,
it_reprs,
lightfm,
num_threads)

{nogil_block}
user_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
it_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
for i in {range_block}(no_examples):
if precompute:
for j in {range_block}(lightfm.no_components + 1):
user_repr[j] = user_reprs[inverse_users[i] * (lightfm.no_components + 1) + j]
it_repr[j] = it_reprs[inverse_items[i] * (lightfm.no_components + 1) + j]
else:
compute_representation(user_features,
lightfm.user_features,
lightfm.user_biases,
lightfm,
user_ids[i],
lightfm.user_scale,
user_repr)
compute_representation(item_features,
lightfm.item_features,
lightfm.item_biases,
lightfm,
item_ids[i],
lightfm.item_scale,
it_repr)

predictions[i] = compute_prediction_from_repr(user_repr,
it_repr,
lightfm.no_components)
it_repr,
lightfm.no_components)

free(user_repr)
free(it_repr)
if precompute:
free(user_reprs)
free(it_reprs)


def predict_ranks(CSRMatrix item_features,
Expand Down Expand Up @@ -1341,3 +1419,4 @@ def __test_in_positives(int row, int col, CSRMatrix mat):
return True
else:
return False

Loading