Skip to content

Commit d5627fd

Browse files
task.py
1 parent 99e54b2 commit d5627fd

File tree

1 file changed

+14
-29
lines changed
  • official/recommendation/ranking

1 file changed

+14
-29
lines changed

official/recommendation/ranking/task.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def _get_tpu_embedding_feature_config(
6767
table_config = tf.tpu.experimental.embedding.TableConfig(
6868
vocabulary_size=vocab_size,
6969
dim=embedding_dim[i],
70-
combiner='mean',
71-
initializer=tf.initializers.TruncatedNormal(
72-
mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
70+
combiner='sum',
71+
initializer=tf.initializers.RandomUniform(
72+
minval= - 1.0 / math.sqrt(vocab_size, maxval = 1.0 / math.sqrt(vocab_size))),
7373
name=table_name_prefix + '_%02d' % i)
7474
feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
7575
name=str(i),
@@ -149,29 +149,17 @@ def build_model(self) -> tf_keras.Model:
149149
A Ranking model instance.
150150
"""
151151
lr_config = self.optimizer_config.lr_config
152-
lr_callable = common.WarmUpAndPolyDecay(
153-
batch_size=self.task_config.train_data.global_batch_size,
154-
decay_exp=lr_config.decay_exp,
152+
embedding_optimizer = tf.kears.optimizers.legacy.Adagrad(
155153
learning_rate=lr_config.learning_rate,
156-
warmup_steps=lr_config.warmup_steps,
157-
decay_steps=lr_config.decay_steps,
158-
decay_start_steps=lr_config.decay_start_steps)
159-
embedding_optimizer = tf_keras.optimizers.get(
160-
self.optimizer_config.embedding_optimizer, use_legacy_optimizer=True)
161-
embedding_optimizer.learning_rate = lr_callable
162-
163-
dense_optimizer = tf_keras.optimizers.get(
164-
self.optimizer_config.dense_optimizer, use_legacy_optimizer=True)
165-
if self.optimizer_config.dense_optimizer == 'SGD':
166-
dense_lr_config = self.optimizer_config.dense_sgd_config
167-
dense_lr_callable = common.WarmUpAndPolyDecay(
168-
batch_size=self.task_config.train_data.global_batch_size,
169-
decay_exp=dense_lr_config.decay_exp,
170-
learning_rate=dense_lr_config.learning_rate,
171-
warmup_steps=dense_lr_config.warmup_steps,
172-
decay_steps=dense_lr_config.decay_steps,
173-
decay_start_steps=dense_lr_config.decay_start_steps)
174-
dense_optimizer.learning_rate = dense_lr_callable
154+
initial_accumulator_value=lr_config.initial_accumulator_value,
155+
epsilon=lr_config.epsilon,
156+
)
157+
158+
dense_optimizer = tf.kears.optimizers.legacy.Adagrad(
159+
learning_rate=lr_config.learning_rate,
160+
initial_accumulator_value=lr_config.initial_accumulator_value,
161+
epsilon=lr_config.epsilon,
162+
)
175163

176164
feature_config = _get_tpu_embedding_feature_config(
177165
embedding_dim=self.task_config.model.embedding_dim,
@@ -208,9 +196,6 @@ def build_model(self) -> tf_keras.Model:
208196
tfrs.layers.feature_interaction.MultiLayerDCN(
209197
projection_dim=self.task_config.model.dcn_low_rank_dim,
210198
num_layers=self.task_config.model.dcn_num_layers,
211-
use_bias=self.task_config.model.dcn_use_bias,
212-
kernel_initializer=self.task_config.model.dcn_kernel_initializer,
213-
bias_initializer=self.task_config.model.dcn_bias_initializer,
214199
),
215200
])
216201
else:
@@ -226,7 +211,7 @@ def build_model(self) -> tf_keras.Model:
226211
),
227212
feature_interaction=feature_interaction,
228213
top_stack=tfrs.layers.blocks.MLP(
229-
units=self.task_config.model.top_mlp, final_activation='sigmoid'
214+
units=self.task_config.model.top_mlp
230215
),
231216
concat_dense=self.task_config.model.concat_dense,
232217
)

0 commit comments

Comments
 (0)