 from deepctr.layers.utils import NoMask, combined_dnn_input, add_func
 from tensorflow.python.keras.layers import Concatenate, Lambda
 from tensorflow.python.keras.models import Model
+
 from ..inputs import create_embedding_matrix
-from ..layers.core import CapsuleLayer, PoolingLayer, LabelAwareAttention, SampledSoftmaxLayer, EmbeddingIndex
+from ..layers.core import CapsuleLayer, PoolingLayer, MaskUserEmbedding, LabelAwareAttention, SampledSoftmaxLayer, \
+    EmbeddingIndex
 from ..layers.interaction import SoftmaxWeightedSum
 from ..utils import get_item_embedding
 
 
-def tile_user_otherfeat(user_other_feature, interest_num):
-    return tf.tile(tf.expand_dims(user_other_feature, -2), [1, interest_num, 1])
+def tile_user_otherfeat(user_other_feature, k_max):
+    return tf.tile(tf.expand_dims(user_other_feature, -2), [1, k_max, 1])
 
 
-def tile_user_his_mask(hist_len, seq_max_len, interest_num):
-    return tf.tile(tf.sequence_mask(hist_len, seq_max_len), [1, interest_num, 1])
+def tile_user_his_mask(hist_len, seq_max_len, k_max):
+    return tf.tile(tf.sequence_mask(hist_len, seq_max_len), [1, k_max, 1])
 
 
 def softmax_Weighted_Sum(input):
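Note on the two tile_* helpers above: they broadcast per-user tensors across the k_max interest slots so that downstream ops can treat each interest independently. A minimal sketch (not part of the diff) of the shapes they produce, assuming TensorFlow 2.x eager mode:

import tensorflow as tf

hist_len = tf.constant([[3], [1]])  # [batch_size, 1] true history lengths
seq_max_len, k_max = 5, 2

# tile_user_his_mask: one boolean [k_max, max_len] validity mask per user
mask = tf.tile(tf.sequence_mask(hist_len, seq_max_len), [1, k_max, 1])
print(mask.shape)  # (2, 2, 5)

# tile_user_otherfeat: repeat a [batch_size, dim] feature for every interest
feat = tf.random.normal([2, 8])
tiled = tf.tile(tf.expand_dims(feat, -2), [1, k_max, 1])
print(tiled.shape)  # (2, 2, 8)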
@@ -37,20 +39,19 @@ def softmax_Weighted_Sum(input):
     return high_capsule
 
 
-def ComiRec(user_feature_columns, item_feature_columns, interest_num=2, p=100, interest_extractor='sa', add_pos=False,
+def ComiRec(user_feature_columns, item_feature_columns, k_max=2, p=100, interest_extractor='sa',
+            add_pos=True,
             user_dnn_hidden_units=(64, 32), dnn_activation='relu', dnn_use_bn=False, l2_reg_dnn=0,
             l2_reg_embedding=1e-6,
             dnn_dropout=0, output_activation='linear', sampler_config=None, seed=1024):
     """Instantiates the ComiRec Model architecture.
 
     :param user_feature_columns: An iterable containing user's features used by the model.
     :param item_feature_columns: An iterable containing item's features used by the model.
-    :param num_sampled: int, the number of classes to randomly sample per batch.
-    :param interest_num: int, the max size of user interest embedding
+    :param k_max: int, the max size of user interest embedding
     :param p: float, the parameter for adjusting the attention distribution in LabelAwareAttention.
     :param interest_extractor: string, type of a multi-interest extraction module, 'sa' means self-attentive and 'dr' means dynamic routing
     :param add_pos: bool. Whether use positional encoding layer
-    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net
     :param user_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of user tower
     :param dnn_activation: Activation function to use in deep net
     :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net
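A note on `p` (passed below as `pow_p` to LabelAwareAttention): the interest-vs-target similarity scores are raised to the power `p` before normalization, so a large `p` concentrates the final user embedding on the single interest closest to the target item. The actual deepmatch layer's internals differ in detail (it normalizes with a softmax); this normalized-power numpy sketch only illustrates the sharpening effect:

import numpy as np

def label_aware_pool(interests, target, p):
    """Weighted-sum the k_max interest vectors, sharpened by p."""
    scores = interests @ target                    # [k_max] similarity per interest
    w = np.power(np.maximum(scores, 1e-9), p)      # pow sharpens the distribution
    w = w / w.sum()                                # normalize into attention weights
    return w @ interests                           # final user embedding, [dim]

interests = np.array([[1.0, 0.0], [0.6, 0.8]])     # two unit-norm interest vectors
target = np.array([1.0, 0.0])                      # target item embedding
print(label_aware_pool(interests, target, p=1))    # blends both interests
print(label_aware_pool(interests, target, p=100))  # ~= the closest interest only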
@@ -131,29 +132,25 @@ def ComiRec(user_feature_columns, item_feature_columns, interest_num=2, p=100, i
     if interest_extractor.lower() == 'dr':
         high_capsule = CapsuleLayer(input_units=item_embedding_dim,
                                     out_units=item_embedding_dim, max_len=seq_max_len,
-                                    k_max=interest_num)((history_emb, hist_len))
+                                    k_max=k_max)((history_emb, hist_len))
     elif interest_extractor.lower() == 'sa':
         history_emb_add_pos = history_emb
         if add_pos:
             position_embedding = PositionEncoding()(history_emb)
             history_emb_add_pos = add_func([history_emb_add_pos, position_embedding])  # [None, max_len, emb_dim]
 
-        attn = DNN((item_embedding_dim * 4, interest_num), activation='tanh', l2_reg=l2_reg_dnn,
+        attn = DNN((item_embedding_dim * 4, k_max), activation='tanh', l2_reg=l2_reg_dnn,
                    dropout_rate=dnn_dropout, use_bn=dnn_use_bn, output_activation=None, seed=seed,
                    name="user_dnn_attn")(history_emb_add_pos)
-        mask = Lambda(tile_user_his_mask, arguments={'interest_num': interest_num,
+        mask = Lambda(tile_user_his_mask, arguments={'k_max': k_max,
                                                      'seq_max_len': seq_max_len})(
-            hist_len)  # [None, interest_num, max_len]
-        # high_capsule = SoftmaxWeightedSum(dropout_rate=0, future_binding=False,
-        #                                   seed=seed)([attn, history_emb_add_pos, mask])
+            hist_len)  # [None, k_max, max_len]
+
         high_capsule = Lambda(softmax_Weighted_Sum)((history_emb_add_pos, mask, attn))
 
-        print("high_capsule",
-              high_capsule)  # Tensor("softmax_weighted_sum/MatMul:0", shape=(None, 2, 32), dtype=float32) Tensor("capsule_layer/Reshape_1:0", shape=(None, 2, 32), dtype=float32)
     if len(dnn_input_emb_list) > 0 or len(dense_value_list) > 0:
         user_other_feature = combined_dnn_input(dnn_input_emb_list, dense_value_list)
-        other_feature_tile = Lambda(tile_user_otherfeat, arguments={'interest_num': interest_num})(user_other_feature)
-        print("other_feature_tile", other_feature_tile, "NoMask", NoMask()(other_feature_tile))
+        other_feature_tile = Lambda(tile_user_otherfeat, arguments={'k_max': k_max})(user_other_feature)
         user_deep_input = Concatenate()([NoMask()(other_feature_tile), high_capsule])
     else:
         user_deep_input = high_capsule
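The 'sa' branch above is ComiRec's self-attentive extractor: a two-layer tanh DNN scores every history position for each of the k_max interest slots, padded positions are masked out, and a softmax-weighted sum over the history yields one embedding per interest. softmax_Weighted_Sum's body is elided from this diff, so the following numpy sketch is an assumption about its computation; the shapes match the (None, k_max, emb_dim) comments above:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, max_len, dim, k_max = 2, 5, 32, 2
history_emb = np.random.randn(batch, max_len, dim)
attn = np.random.randn(batch, max_len, k_max)      # DNN((dim * 4, k_max)) output
mask = np.zeros((batch, k_max, max_len), dtype=bool)
mask[0, :, :3] = True   # user 0 has 3 real items
mask[1, :, :1] = True   # user 1 has 1 real item

scores = np.transpose(attn, (0, 2, 1))             # [batch, k_max, max_len]
scores = np.where(mask, scores, -2.0 ** 31)        # exclude padded positions
weights = softmax(scores, axis=-1)                 # attention over the history
high_capsule = weights @ history_emb               # [batch, k_max, dim]
print(high_capsule.shape)                          # (2, 2, 32)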
@@ -173,7 +170,8 @@ def ComiRec(user_feature_columns, item_feature_columns, interest_num=2, p=100, i
 
     pooling_item_embedding_weight = PoolingLayer()([item_embedding_weight])
 
-    user_embedding_final = LabelAwareAttention(k_max=interest_num, pow_p=p)((user_embeddings, target_emb))
+    user_embedding_final = LabelAwareAttention(k_max=k_max, pow_p=p)((user_embeddings, target_emb))
+
     output = SampledSoftmaxLayer(sampler_config._asdict())(
         [pooling_item_embedding_weight, user_embedding_final, item_features[item_feature_name]])
     model = Model(inputs=inputs_list + item_inputs_list, outputs=output)
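Finally, SampledSoftmaxLayer trains the retrieval head with a sampled softmax over the item vocabulary instead of a full softmax; sampler_config (a namedtuple, passed here as a dict) carries the sampler type and the number of negatives. It builds on TensorFlow's sampled-softmax op; a rough self-contained sketch of the underlying idea, where the shapes, the zero bias, and num_sampled are assumptions rather than the layer's actual code:

import tensorflow as tf

item_count, dim, num_sampled = 1000, 32, 5
item_embedding_weight = tf.Variable(tf.random.normal([item_count, dim]))
zero_bias = tf.zeros([item_count])
user_embedding_final = tf.random.normal([4, dim])                 # [batch, dim]
target_item_idx = tf.constant([[1], [42], [7], [999]], tf.int64)  # [batch, 1]

# Full-vocabulary softmax is replaced by a loss over num_sampled negatives.
loss = tf.nn.sampled_softmax_loss(
    weights=item_embedding_weight, biases=zero_bias,
    labels=target_item_idx, inputs=user_embedding_final,
    num_sampled=num_sampled, num_classes=item_count)
print(loss.shape)  # (4,): one loss value per example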