Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C-EASE and ADD-EASE #696

Open
wants to merge 5 commits into
base: 0.2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions recbole/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from recbole.model.general_recommender.addease import ADDEASE
from recbole.model.general_recommender.bpr import BPR
from recbole.model.general_recommender.cease import CEASE
from recbole.model.general_recommender.convncf import ConvNCF
from recbole.model.general_recommender.dgcf import DGCF
from recbole.model.general_recommender.dmf import DMF
Expand Down
124 changes: 124 additions & 0 deletions recbole/model/general_recommender/addease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@

r"""
ADD-EASE
################################################
Reference:
Olivier Jeunen, et al. "Closed-Form Models for Collaborative Filtering with Side-Information".

Reference code:
https://github.com/olivierjeunen/ease-side-info-recsys-2020/
"""


from recbole.utils.enum_type import ModelType, FeatureType
import numpy as np
import scipy.sparse as sp
import torch

from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender

from sklearn.preprocessing import MultiLabelBinarizer, OneHotEncoder


def encode_categorical_item_features(dataset, selected_features):
item_features = dataset.get_item_feature()

mlb = MultiLabelBinarizer(sparse_output=True)
ohe = OneHotEncoder(sparse=True)

encoded_feats = []

for feat in selected_features:
t = dataset.field2type[feat]
feat_frame = item_features[feat].numpy()

if t == FeatureType.TOKEN:
encoded = ohe.fit_transform(feat_frame.reshape(-1, 1))
encoded_feats.append(encoded)
elif t == FeatureType.TOKEN_SEQ:
encoded = mlb.fit_transform(feat_frame)

# drop first column which corresponds to the padding 0; real categories start at 1
# convert to csc first?
encoded = encoded[:, 1:]
encoded_feats.append(encoded)
else:
raise Warning(
f'ADD-EASE only supports token or token_seq types. [{feat}] is of type [{t}].')

if not encoded_feats:
raise ValueError(
f'No valid token or token_seq features to include.')

return sp.hstack(encoded_feats).T.astype(np.float32)


def ease_like(M, reg_weight):
# gram matrix
G = M.T @ M

# add reg to diagonal
G += reg_weight * sp.identity(G.shape[0])

# convert to dense because inverse will be dense
G = G.todense()

# invert. this takes most of the time
P = np.linalg.inv(G)
B = P / (-np.diag(P))
# zero out diag
np.fill_diagonal(B, 0.)

return B


class ADDEASE(GeneralRecommender):
input_type = InputType.POINTWISE
type = ModelType.TRADITIONAL

def __init__(self, config, dataset):
super().__init__(config, dataset)

# need at least one param
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

inter_matrix = dataset.inter_matrix(
form='csr').astype(np.float32)

item_feat_proportion = config['item_feat_proportion']
inter_reg_weight = config['inter_reg_weight']
item_reg_weight = config['item_reg_weight']
selected_features = config['selected_features']

tag_item_matrix = encode_categorical_item_features(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To sovle the above issue, you can clip the tag_item_matrix by num_items.
tag_item_matrix = tag_item_matrix[:, :self.num_items]

Or just filter the items by config parameter.
tem_inter_num_interval: [1,inf)

dataset, selected_features)

inter_S = ease_like(inter_matrix, inter_reg_weight)
item_S = ease_like(tag_item_matrix, item_reg_weight)

# instead of computing and storing the entire score matrix, just store B and compute the scores on demand
# more memory efficient for a larger number of users

# torch doesn't support sparse tensor slicing, so will do everything with np/scipy
self.item_similarity = (1-item_feat_proportion) * \
inter_S + item_feat_proportion * item_S
self.interaction_matrix = inter_matrix

def forward(self):
pass

def calculate_loss(self, interaction):
return torch.nn.Parameter(torch.zeros(1))

def predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()
item = interaction[self.ITEM_ID].cpu().numpy()

return torch.from_numpy((self.interaction_matrix[user, :].multiply(self.item_similarity[:, item].T)).sum(axis=1).getA1())

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()

r = self.interaction_matrix[user, :] @ self.item_similarity
return torch.from_numpy(r.flatten())
126 changes: 126 additions & 0 deletions recbole/model/general_recommender/cease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@

r"""
C-EASE
################################################
Reference:
Olivier Jeunen, et al. "Closed-Form Models for Collaborative Filtering with Side-Information".

Reference code:
https://github.com/olivierjeunen/ease-side-info-recsys-2020/
"""


from recbole.utils.enum_type import ModelType, FeatureType
import numpy as np
import scipy.sparse as sp
import torch

from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender

from sklearn.preprocessing import MultiLabelBinarizer, OneHotEncoder


def encode_categorical_item_features(dataset, selected_features):
item_features = dataset.get_item_feature()

mlb = MultiLabelBinarizer(sparse_output=True)
ohe = OneHotEncoder(sparse=True)

encoded_feats = []

for feat in selected_features:
t = dataset.field2type[feat]
feat_frame = item_features[feat].numpy()

if t == FeatureType.TOKEN:
encoded = ohe.fit_transform(feat_frame.reshape(-1, 1))
encoded_feats.append(encoded)
elif t == FeatureType.TOKEN_SEQ:
encoded = mlb.fit_transform(feat_frame)

# drop first column which corresponds to the padding 0; real categories start at 1
# convert to csc first?
encoded = encoded[:, 1:]
encoded_feats.append(encoded)
else:
raise Warning(
f'CEASE only supports token or token_seq types. [{feat}] is of type [{t}].')

if not encoded_feats:
raise ValueError(
f'No valid token or token_seq features to include.')

return sp.hstack(encoded_feats).T.astype(np.float32)


def ease_like(M, reg_weight):
# gram matrix
G = M.T @ M

# add reg to diagonal
G += reg_weight * sp.identity(G.shape[0])

# convert to dense because inverse will be dense
G = G.todense()

# invert. this takes most of the time
P = np.linalg.inv(G)
B = P / (-np.diag(P))
# zero out diag
np.fill_diagonal(B, 0.)

return B


class CEASE(GeneralRecommender):
input_type = InputType.POINTWISE
type = ModelType.TRADITIONAL

def __init__(self, config, dataset):
super().__init__(config, dataset)

# need at least one param
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

inter_matrix = dataset.inter_matrix(
form='csr').astype(np.float32)

item_feat_weight = config['item_feat_weight']
reg_weight = config['reg_weight']
selected_features = config['selected_features']

tag_item_matrix = item_feat_weight * \
encode_categorical_item_features(dataset, selected_features)

# just directly calculate the entire score matrix in init
# (can't be done incrementally)

X = sp.vstack([inter_matrix, tag_item_matrix]).tocsr()

item_similarity = ease_like(X, reg_weight)

# instead of computing and storing the entire score matrix, just store B and compute the scores on demand
# more memory efficient for a larger number of users

# torch doesn't support sparse tensor slicing, so will do everything with np/scipy
self.item_similarity = item_similarity
self.interaction_matrix = inter_matrix

def forward(self):
pass

def calculate_loss(self, interaction):
return torch.nn.Parameter(torch.zeros(1))

def predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()
item = interaction[self.ITEM_ID].cpu().numpy()

return torch.from_numpy((self.interaction_matrix[user, :].multiply(self.item_similarity[:, item].T)).sum(axis=1).getA1())

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()

r = self.interaction_matrix[user, :] @ self.item_similarity
return torch.from_numpy(r.flatten())
4 changes: 4 additions & 0 deletions recbole/properties/model/ADDEASE.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
item_feat_proportion: 0.001
inter_reg_weight: 350.0
item_reg_weight: 150.0
selected_features: ['class']
3 changes: 3 additions & 0 deletions recbole/properties/model/CEASE.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
item_feat_weight: 10.0
reg_weight: 350.0
selected_features: ['class']
8 changes: 8 additions & 0 deletions run_test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@
'model': 'MacridVAE',
'dataset': 'ml-100k',
},
'Test CEASE': {
'model': 'CEASE',
'dataset': 'ml-100k',
},
'Test ADDEASE': {
'model': 'ADDEASE',
'dataset': 'ml-100k',
},

# Context-aware Recommendation
'Test FM': {
Expand Down
12 changes: 12 additions & 0 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ def test_CDAE(self):
}
quick_test(config_dict)

def test_CEASE(self):
config_dict = {
'model': 'CEASE',
}
quick_test(config_dict)

def test_ADDEASE(self):
config_dict = {
'model': 'ADDEASE',
}
quick_test(config_dict)

class TestContextRecommender(unittest.TestCase):
# todo: more complex context information should be test, such as criteo dataset

Expand Down