
Commit 93ca437

add comirec model

2 parents 5dab795 + b681eb8 commit 93ca437

25 files changed: +786 −48 lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ Steps to reproduce the behavior:
 **Operating environment(运行环境):**
 - python version [e.g. 3.6, 3.7, 3.8]
 - tensorflow version [e.g. 1.9.0, 1.14.0, 2.5.0]
-- deepmatch version [e.g. 0.3.0,]
+- deepmatch version [e.g. 0.3.1,]
 
 **Additional context**
 Add any other context about the problem here.

.github/ISSUE_TEMPLATE/question.md

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ Add any other context about the problem here.
 **Operating environment(运行环境):**
 - python version [e.g. 3.6, 3.7, 3.8]
 - tensorflow version [e.g. 1.9.0, 1.14.0, 2.5.0]
-- deepmatch version [e.g. 0.3.0,]
+- deepmatch version [e.g. 0.3.1,]

.github/workflows/ci.yml

Lines changed: 19 additions & 2 deletions
@@ -17,7 +17,7 @@ jobs:
     timeout-minutes: 120
     strategy:
       matrix:
-        python-version: [3.6,3.7,3.8]
+        python-version: [3.6,3.7,3.8,3.9,3.10.7]
         tf-version: [1.9.0,1.14.0,2.5.0]
 
         exclude:
@@ -57,12 +57,28 @@
             tf-version: 2.8.0
           - python-version: 3.6
             tf-version: 2.9.0
+          - python-version: 3.6
+            tf-version: 2.10.0
           - python-version: 3.9
             tf-version: 1.4.0
+          - python-version: 3.9
+            tf-version: 1.9.0
           - python-version: 3.9
             tf-version: 1.15.0
           - python-version: 3.9
-            tf-version: 2.2.0
+            tf-version: 1.14.0
+          - python-version: 3.10.7
+            tf-version: 1.4.0
+          - python-version: 3.10.7
+            tf-version: 1.9.0
+          - python-version: 3.10.7
+            tf-version: 1.15.0
+          - python-version: 3.10.7
+            tf-version: 1.14.0
+          - python-version: 3.10.7
+            tf-version: 2.5.0
+          - python-version: 3.10.7
+            tf-version: 2.6.0
 
     steps:
 
@@ -75,6 +91,7 @@ jobs:
 
       - name: Install dependencies
         run: |
+          sudo apt update && sudo apt install -y pkg-config libhdf5-dev
           pip3 install -q tensorflow==${{ matrix.tf-version }}
           pip install -q protobuf==3.19.0
           pip install -q requests

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+.idea
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

Lines changed: 6 additions & 0 deletions
@@ -32,6 +32,7 @@ Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.
 | NCF | [WWW 2017][Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
 | SDM | [CIKM 2019][SDM: Sequential Deep Matching Model for Online Large-scale Recommender System](https://arxiv.org/abs/1909.00385) |
 | MIND | [CIKM 2019][Multi-interest network with dynamic routing for recommendation at Tmall](https://arxiv.org/pdf/1904.08030) |
+| COMIREC | [KDD 2020][Controllable Multi-Interest Framework for Recommendation](https://arxiv.org/pdf/2005.09347.pdf) |
 
 ## Contributors([welcome to join us!](./CONTRIBUTING.md))
 
@@ -60,6 +61,11 @@ Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.
       <a href="https://github.com/LeoCai">LeoCai</a>
       <p> ByteDance </p>
     </td>
+    <td>
+      <a href="https://github.com/liyuan97"><img width="70" height="70" src="https://github.com/liyuan97.png?s=40" alt="pic"></a><br>
+      <a href="https://github.com/liyuan97">Li Yuan</a>
+      <p> Tencent </p>
+    </td>
     <td>
       <a href="https://github.com/yangjieyu"><img width="70" height="70" src="https://github.com/yangjieyu.png?s=40" alt="pic"></a><br>
       <a href="https://github.com/yangjieyu">Yang Jieyu</a>

deepmatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .utils import check_version
 
-__version__ = '0.3.0'
+__version__ = '0.3.1'
 check_version(__version__)

deepmatch/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 from .ncf import NCF
 from .mind import MIND
 from .sdm import SDM
+from .comirec import ComiRec

deepmatch/models/comirec.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
"""
Author:

Reference:
    Yukuo Cen, Jianwei Zhang, Xu Zou, et al. Controllable Multi-Interest Framework for Recommendation. KDD 2020.
"""

import tensorflow as tf
from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, \
    embedding_lookup, varlen_embedding_lookup, get_varlen_pooling_list, get_dense_input, build_input_features
from deepctr.layers import DNN, PositionEncoding
from deepctr.layers.utils import NoMask, combined_dnn_input, add_func
from tensorflow.python.keras.layers import Concatenate, Lambda
from tensorflow.python.keras.models import Model

from ..inputs import create_embedding_matrix
from ..layers.core import CapsuleLayer, PoolingLayer, LabelAwareAttention, SampledSoftmaxLayer, \
    EmbeddingIndex
from ..utils import get_item_embedding


def tile_user_otherfeat(user_other_feature, k_max):
    # Repeat the static user feature vector once per interest: [B, D] -> [B, k_max, D]
    return tf.tile(tf.expand_dims(user_other_feature, -2), [1, k_max, 1])


def tile_user_his_mask(hist_len, seq_max_len, k_max):
    # Build a boolean mask over the padded history and repeat it per interest:
    # [B, 1, seq_max_len] -> [B, k_max, seq_max_len]
    return tf.tile(tf.sequence_mask(hist_len, seq_max_len), [1, k_max, 1])


def softmax_weighted_sum(inputs):
    history_emb_add_pos, mask, attn = inputs[0], inputs[1], inputs[2]
    attn = tf.transpose(attn, [0, 2, 1])  # [B, k_max, seq_len]
    pad = tf.ones_like(mask, dtype=tf.float32) * (-2 ** 32 + 1)
    attn = tf.where(mask, attn, pad)  # push padded positions to -inf before softmax
    attn = tf.nn.softmax(attn)  # [B, k_max, seq_len], normalized over the sequence axis
    high_capsule = tf.matmul(attn, history_emb_add_pos)  # [B, k_max, emb_dim]
    return high_capsule
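

# Note: softmax_weighted_sum above is the self-attentive ("sa") multi-interest
# extractor of the ComiRec paper. The attention logits come from the two-layer
# DNN inside ComiRec() below (tanh hidden layer, linear output), matching the
# paper's A = softmax(W2^T tanh(W1 H^T)) with hidden width d_a = 4d; padded
# positions are masked out before the softmax, and tf.matmul computes the
# attention-weighted sum over the (position-encoded) history, producing one
# interest capsule per attention head.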


def ComiRec(user_feature_columns, item_feature_columns, k_max=2, p=100, interest_extractor='sa',
            add_pos=True,
            user_dnn_hidden_units=(64, 32), dnn_activation='relu', dnn_use_bn=False, l2_reg_dnn=0,
            l2_reg_embedding=1e-6,
            dnn_dropout=0, output_activation='linear', sampler_config=None, seed=1024):
    """Instantiates the ComiRec Model architecture.

    :param user_feature_columns: An iterable containing the user's features used by the model.
    :param item_feature_columns: An iterable containing the item's features used by the model.
    :param k_max: int, the maximum number of user interest embeddings
    :param p: float, the parameter for adjusting the attention distribution in LabelAwareAttention.
    :param interest_extractor: string, type of multi-interest extraction module: 'sa' means self-attentive, 'dr' means dynamic routing
    :param add_pos: bool. Whether to use a positional encoding layer
    :param user_dnn_hidden_units: list of positive integers or empty list, the layer number and units in each layer of the user tower
    :param dnn_activation: Activation function to use in the deep net
    :param dnn_use_bn: bool. Whether to use BatchNormalization before activation in the deep net
    :param l2_reg_dnn: L2 regularizer strength applied to the DNN
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vectors
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param output_activation: Activation function to use in the output layer
    :param sampler_config: negative sampling config.
    :param seed: integer, to use as random seed.
    :return: A Keras model instance.

    """

    if len(item_feature_columns) > 1:
        raise ValueError("Now ComiRec only supports a single item feature like item_id")
    if interest_extractor.lower() not in ['dr', 'sa']:
        raise ValueError("Now ComiRec only supports the 'dr' and 'sa' interest extractors")
    item_feature_column = item_feature_columns[0]
    item_feature_name = item_feature_column.name
    item_vocabulary_size = item_feature_column.vocabulary_size
    item_embedding_dim = item_feature_column.embedding_dim
    if user_dnn_hidden_units[-1] != item_embedding_dim:
        user_dnn_hidden_units = tuple(list(user_dnn_hidden_units) + [item_embedding_dim])

    history_feature_list = [item_feature_name]

    features = build_input_features(user_feature_columns)
    sparse_feature_columns = list(
        filter(lambda x: isinstance(x, SparseFeat), user_feature_columns)) if user_feature_columns else []
    dense_feature_columns = list(
        filter(lambda x: isinstance(x, DenseFeat), user_feature_columns)) if user_feature_columns else []
    varlen_sparse_feature_columns = list(
        filter(lambda x: isinstance(x, VarLenSparseFeat), user_feature_columns)) if user_feature_columns else []
    history_feature_columns = []
    sparse_varlen_feature_columns = []
    history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list))
    for fc in varlen_sparse_feature_columns:
        feature_name = fc.name
        if feature_name in history_fc_names:
            history_feature_columns.append(fc)
        else:
            sparse_varlen_feature_columns.append(fc)
    seq_max_len = history_feature_columns[0].maxlen
    inputs_list = list(features.values())

    embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding,
                                                    seed=seed, prefix="")

    item_features = build_input_features(item_feature_columns)

    query_emb_list = embedding_lookup(embedding_matrix_dict, item_features, item_feature_columns,
                                      history_feature_list,
                                      history_feature_list, to_list=True)
    keys_emb_list = embedding_lookup(embedding_matrix_dict, features, history_feature_columns, history_fc_names,
                                     history_fc_names, to_list=True)
    dnn_input_emb_list = embedding_lookup(embedding_matrix_dict, features, sparse_feature_columns,
                                          mask_feat_list=history_feature_list, to_list=True)
    dense_value_list = get_dense_input(features, dense_feature_columns)

    sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, sparse_varlen_feature_columns)
    sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, features, sparse_varlen_feature_columns,
                                                  to_list=True)

    dnn_input_emb_list += sequence_embed_list

    history_emb = PoolingLayer()(NoMask()(keys_emb_list))  # [None, max_len, emb_dim]
    target_emb = PoolingLayer()(NoMask()(query_emb_list))

    hist_len = features['hist_len']

    high_capsule = None
    if interest_extractor.lower() == 'dr':
        # Dynamic-routing extractor (as in MIND): route history embeddings into k_max interest capsules
        high_capsule = CapsuleLayer(input_units=item_embedding_dim,
                                    out_units=item_embedding_dim, max_len=seq_max_len,
                                    k_max=k_max)((history_emb, hist_len))
    elif interest_extractor.lower() == 'sa':
        # Self-attentive extractor: optionally add positional encodings, then
        # compute per-interest attention over the history with a small DNN
        history_emb_add_pos = history_emb
        if add_pos:
            position_embedding = PositionEncoding()(history_emb)
            history_emb_add_pos = add_func([history_emb_add_pos, position_embedding])  # [None, max_len, emb_dim]

        attn = DNN((item_embedding_dim * 4, k_max), activation='tanh', l2_reg=l2_reg_dnn,
                   dropout_rate=dnn_dropout, use_bn=dnn_use_bn, output_activation=None, seed=seed,
                   name="user_dnn_attn")(history_emb_add_pos)
        mask = Lambda(tile_user_his_mask, arguments={'k_max': k_max,
                                                     'seq_max_len': seq_max_len})(
            hist_len)  # [None, k_max, max_len]

        high_capsule = Lambda(softmax_weighted_sum)((history_emb_add_pos, mask, attn))
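
    # Both branches produce high_capsule of shape [None, k_max, emb_dim]:
    # k_max candidate interest vectors per user. The user-tower DNN below maps
    # each of them into the item embedding space, and LabelAwareAttention later
    # selects among them using the target item during training.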

    if len(dnn_input_emb_list) > 0 or len(dense_value_list) > 0:
        # Tile the non-sequence user features so they can be concatenated with each interest
        user_other_feature = combined_dnn_input(dnn_input_emb_list, dense_value_list)
        other_feature_tile = Lambda(tile_user_otherfeat, arguments={'k_max': k_max})(user_other_feature)
        user_deep_input = Concatenate()([NoMask()(other_feature_tile), high_capsule])
    else:
        user_deep_input = high_capsule

    user_embeddings = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn,
                          dnn_dropout, dnn_use_bn, output_activation=output_activation, seed=seed,
                          name="user_dnn")(
        user_deep_input)

    item_inputs_list = list(item_features.values())

    item_embedding_matrix = embedding_matrix_dict[item_feature_name]

    item_index = EmbeddingIndex(list(range(item_vocabulary_size)))(item_features[item_feature_name])

    item_embedding_weight = NoMask()(item_embedding_matrix(item_index))

    pooling_item_embedding_weight = PoolingLayer()([item_embedding_weight])

    user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p)((user_embeddings, target_emb))

    output = SampledSoftmaxLayer(sampler_config._asdict())(
        [pooling_item_embedding_weight, user_embedding_final, item_features[item_feature_name]])
    model = Model(inputs=inputs_list + item_inputs_list, outputs=output)

    model.__setattr__("user_input", inputs_list)
    model.__setattr__("user_embedding", user_embeddings)

    model.__setattr__("item_input", item_inputs_list)
    model.__setattr__("item_embedding",
                      get_item_embedding(pooling_item_embedding_weight, item_features[item_feature_name]))

    return model
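
For reference, a minimal usage sketch (not part of this commit). It assumes DeepCTR-style feature columns, a behavior-sequence feature named hist_item_id whose length is exposed as hist_len (the model reads features['hist_len']), and the NegativeSampler and sampledsoftmaxloss helpers from deepmatch.utils; the feature names, sizes, and sampler arguments below are illustrative assumptions, not part of the diff.

    from deepctr.feature_column import SparseFeat, VarLenSparseFeat
    from deepmatch.models import ComiRec
    from deepmatch.utils import NegativeSampler, sampledsoftmaxloss

    n_users, n_items, max_len, emb_dim = 100, 1000, 50, 16

    user_feature_columns = [
        SparseFeat('user_id', n_users, 16),
        # the "hist_" prefix marks the behavior sequence of item_id;
        # length_name='hist_len' creates the features['hist_len'] input
        VarLenSparseFeat(SparseFeat('hist_item_id', n_items, emb_dim, embedding_name='item_id'),
                         maxlen=max_len, length_name='hist_len'),
    ]
    item_feature_columns = [SparseFeat('item_id', n_items, emb_dim)]

    # uniform negative sampling over the item vocabulary (argument names assumed)
    sampler_config = NegativeSampler('uniform', num_sampled=5, item_name='item_id', item_count=None)

    model = ComiRec(user_feature_columns, item_feature_columns, k_max=2,
                    interest_extractor='sa', sampler_config=sampler_config)
    model.compile(optimizer='adam', loss=sampledsoftmaxloss)

With k_max=2 the model learns two interest vectors per user; at serving time model.user_embedding and model.item_embedding can be used to export both towers for retrieval.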

deepmatch/models/mind.py

Lines changed: 1 addition & 3 deletions
@@ -53,11 +53,9 @@ def MIND(user_feature_columns, item_feature_columns, k_max=2, p=100, dynamic_k=F
 
     :param user_feature_columns: An iterable containing user's features used by the model.
     :param item_feature_columns: An iterable containing item's features used by the model.
-    :param num_sampled: int, the number of classes to randomly sample per batch.
     :param k_max: int, the max size of user interest embedding
     :param p: float,the parameter for adjusting the attention distribution in LabelAwareAttention.
     :param dynamic_k: bool, whether or not use dynamic interest number
-    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net
     :param user_dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of user tower
     :param dnn_activation: Activation function to use in deep net
     :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net
@@ -169,7 +167,7 @@ def MIND(user_feature_columns, item_feature_columns, k_max=2, p=100, dynamic_k=F
         user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p)((user_embeddings, target_emb, interest_num))
     else:
         user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p)((user_embeddings, target_emb))
-    print("swc")
+
     output = SampledSoftmaxLayer(sampler_config._asdict())(
         [pooling_item_embedding_weight, user_embedding_final, item_features[item_feature_name]])
     model = Model(inputs=inputs_list + item_inputs_list, outputs=output)

deepmatch/models/sdm.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ def SDM(user_feature_columns, item_feature_columns, history_feature_list, units=
     :param user_feature_columns: An iterable containing user's features used by the model.
     :param item_feature_columns: An iterable containing item's features used by the model.
     :param history_feature_list: list,to indicate short and prefer sequence sparse field
-    :param num_sampled: int, the number of classes to randomly sample per batch.
     :param units: int, dimension for each output layer
     :param rnn_layers: int, layer number of rnn
     :param dropout_rate: float in [0,1), the probability we will drop out a given DNN coordinate.
