diff --git a/generation.py b/generation.py
index 45ce64e..69a100b 100644
--- a/generation.py
+++ b/generation.py
@@ -13,14 +13,15 @@
 from tensorflow.python import debug as tf_debug
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import embedding_ops
+from tensorflow.python import pywrap_tensorflow
 import fastBPE
 import platform
 
 use_py3 = platform.python_version()[0] == '3'
 
 parser = argparse.ArgumentParser(description='TensorFlow code for generating from CTRL')
-parser.add_argument('--model_dir', type=str, required=True,
-                    help='location of model checkpoint')
+parser.add_argument('--model_path', type=str, required=True,
+                    help='path to the model *data* checkpoint file itself; this is NOT the model directory')
 parser.add_argument('--seed', type=int, default=1337,
                     help='random seed for TensorFlow, numpy and PythonHash')
 parser.add_argument('--generate_num', type=int, default=256,
@@ -33,6 +34,10 @@
                     help='topk value for sampling from the softmax distribution ; 0 means no topk preferred')
 parser.add_argument('--penalty', type=float, default=1.2,
                     help='repetition penalty for greedy sampling')
+parser.add_argument('--print_once', action='store_true',
+                    help='print the completion only at the end, not after every generated word')
+parser.add_argument('--topn', type=int, default=0,
+                    help='print the top-n candidates during generation; defaults to 0, which disables printing')
 
 args = parser.parse_args()
 tf.random.set_random_seed(args.seed)
@@ -75,10 +80,10 @@ class TiedEmbeddingSoftmax(tf.keras.layers.Layer):
   def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
     super(TiedEmbeddingSoftmax, self).__init__()
-    self.w = self.add_weight(name='w', shape=(vocab_size, embedding_size),
+    self.w = self.add_weight(name='w', shape=(vocab_size, embedding_size), dtype=tf.float32,
                              initializer='random_normal',
                              trainable=True)
-    self.b = self.add_weight(name='b', shape=(vocab_size,),
+    self.b = self.add_weight(name='b', shape=(vocab_size,), dtype=tf.float32,
                              initializer='zeros',
                              trainable=True)
@@ -132,26 +137,6 @@ def loss(labels, logits):
 print(model.summary())
 
-# IMPORTANT
-# this is where the saved model is presented to the code
-# the model directory should have the model checkpoint and
-# a checkpoint file
-run_config = tf.contrib.tpu.RunConfig(
-        model_dir=args.model_dir)
-
-
-# this converts the Keras model to a TensorFlow estimator
-# this step is critical
-# remember to patch the TF 1.14 file before running the code, else you're going to see errors here
-estimator_model = tf.keras.estimator.model_to_estimator(keras_model=model, config=run_config)
-
-# we now create a serving function from this estimator
-# this enables us to load the model once and easily query it multiple times
-def serving_input_fn():
-  inputs = {'input_1': tf.placeholder(tf.int32, [1,seq_length])}
-  return tf.estimator.export.ServingInputReceiver(inputs, inputs)
-predict_fn = tf.contrib.predictor.from_estimator(estimator_model, serving_input_fn)
-
 # almost there, we now take the user prompt and tokenize with BPE
 # load BPE codes
 bpe = fastBPE.fastBPE('codes', 'vocab')
@@ -161,7 +146,29 @@ def serving_input_fn():
 penalty = args.penalty
 topk = args.topk
 
+# Load the model file
+chkpt_for_reader = '.'.join(args.model_path.split('.')[:-1])
+reader = pywrap_tensorflow.NewCheckpointReader(chkpt_for_reader)
+
+# assign weights from the checkpoint to the Keras model
+# this is super hacky, but I couldn't find a better way to do this
+# a PR is highly welcome if you know of a better way
+
+# embedding and softmax
+# these are fp32
+model.layers[1].trainable_variables[0].assign(tf.cast(reader.get_tensor('w'), tf.float32))
+model.layers[1].trainable_variables[1].assign(tf.cast(reader.get_tensor('b'), tf.float32))
+
+# encoder weights
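+# (each Keras variable name ends in ':0'; tensor.name[:-2] strips that suffix so it
+#  matches the corresponding tensor name stored in the checkpoint)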
+for _ in range(len(model.layers[2].trainable_weights)):
+  tensor = model.layers[2].trainable_weights[_]
+  if 'normalization' in tensor.name[:-2]: # layernorm is fp32
+    tensor.assign(tf.cast(reader.get_tensor(tensor.name[:-2]), tf.float32))
+  else: # everything else is fp16
+    tensor.assign(tf.cast(reader.get_tensor(tensor.name[:-2]), tf.float16))
+
 while True:
+  print('WARNING! THIS VERSION OF THE CODE ALLOWS FOR LOWER MEMORY USAGE THROUGH FP16 QUANTIZATION BUT IS UNTESTED; GENERATIONS MAY BE WORSE. USE AT YOUR OWN RISK. ')
   prompt = raw_input('ENTER PROMPT: ') if not use_py3 else input('ENTER PROMPT: ')
 
   # tokenize provided prompt
@@ -178,13 +185,13 @@ def serving_input_fn():
       # this is done by sliding the window over (past 512 tokens) and continuing prediction
       # I'm sure this can be simplified (TODO)
       if token <= seq_length:
-        prompt_logits = predict_fn({'input_1':tokens_generated[:, :seq_length]})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
+        prompt_logits = model.predict_on_batch(tokens_generated[:, :seq_length]).squeeze() / (temperature if temperature>0 else 1.)
         _token = token if token < seq_length else -1
       else:
         _token = -1
         end = token + 1
         start = token - seq_length + 2
-        prompt_logits = predict_fn({'input_1':np.hstack((tokens_generated[:,0:1], tokens_generated[:,start:end]))})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
+        prompt_logits = model.predict_on_batch(np.hstack((tokens_generated[:,0:1], tokens_generated[:,start:end]))).squeeze() / (temperature if temperature>0 else 1.)
 
       # if penalty (for repetition) is non-zero,
@@ -233,6 +240,8 @@ def serving_input_fn():
         # then we will use the whole list
         nucleus = len(pruned_list)
 
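+      # keep only the first `nucleus` candidates (the nucleus / top-k cutoff computed above)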
+      pruned_list = pruned_list[:nucleus]
+
       # if you want to disallow more complex tokens, you can do so here
       # for instance, if you want to disallow anything with the phrase `http`,
       # you can delete theme from the pruned_list
@@ -243,6 +252,9 @@ def serving_input_fn():
           tokens_to_disallow.append(_)
       pruned_list = np.delete(pruned_list, tokens_to_disallow)
 
+      if args.topn > 0:
+        print('TOPN :: top-n alternatives:', [idx2word[_] for _ in pruned_list[:args.topn]])
+
       # if temperature is 0
       # just pick the first (most probable) token
       if temperature==0:
@@ -250,16 +262,11 @@ def serving_input_fn():
       else:
         # else,
         # sample from the pruned_list with the logits
-        chosen_idx = int(tf.random.categorical(np.expand_dims(prompt_logits[0][_token][pruned_list],0), num_samples=1).numpy())
+        chosen_idx = int(tf.random.categorical(np.expand_dims(prompt_logits[_token][pruned_list],0), num_samples=1).numpy())
         idx = pruned_list[chosen_idx]
-
-      # if you want to do some debugging,
-      # like which one was chosen,
-      # what the top25 were,
-      # here is your opportunity.
-      #print('chosen:', idx2word[idx])
-      #print('top25 alternatives:', pruned_list[:25])
+      if args.topn > 0:
+        print('TOPN :: chosen word:', idx2word[idx])
 
       # assign the token for generation
       tokens_generated[0][token+1] = idx
@@ -270,10 +277,15 @@ def serving_input_fn():
       tokens_generated_so_far = ' '.join([idx2word[c] for c in tokens_generated[0].squeeze()[:token+2]])
       tokens_generated_so_far = re.sub('(@@ )', '', string=tokens_generated_so_far)
-      tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
-      print(tokens_generated_so_far)
-      print()
-
+      tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
+      if not args.print_once:
+        print('---------------------------------------')
+        print(tokens_generated_so_far)
+        print()
+    print('---------------------------------------')
+    print(tokens_generated_so_far)
+    print()
+
   except KeyboardInterrupt: #Exception as e:
     print('Continuing')
diff --git a/transformer.py b/transformer.py
index d1e8592..ec9ec89 100644
--- a/transformer.py
+++ b/transformer.py
@@ -19,15 +19,15 @@ def positional_encoding(position, d_model_size):
 
 def scaled_dot_product_attention(q, k, v, mask):
   # calculate attention
-  matmul_qk = tf.matmul(q, k, transpose_b=True)
+  matmul_qk = tf.cast(tf.matmul(q, k, transpose_b=True), tf.float32)
 
   dk = tf.cast(tf.shape(k)[-1], tf.float32)
   scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
 
   if mask is not None:
-    scaled_attention_logits += (mask * -1e9)
+    scaled_attention_logits += (mask * -1e3)
 
-  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
+  attention_weights = tf.cast(tf.nn.softmax(scaled_attention_logits, axis=-1), tf.float16)
   output = tf.matmul(attention_weights, v)
 
   return output
@@ -85,14 +85,16 @@ def __init__(self, d_model_size, num_heads, dff, rate=0.1):
     self.dropout1 = tf.keras.layers.Dropout(rate)
     self.dropout2 = tf.keras.layers.Dropout(rate)
+    self.to32 = lambda x: tf.cast(x, tf.float32)
+    self.to16 = lambda x: tf.cast(x, tf.float16)
 
   def call(self, x, training, mask):
-    normed = self.layernorm1(x)
+    normed = self.to16(self.layernorm1(self.to32(x)))
     attn_output = self.multi_head_attention(normed, normed, normed, mask)
     attn_output = self.dropout1(attn_output, training=training)
     out1 = x + attn_output
-
-    out2 = self.layernorm2(out1)
+
+    out2 = self.to16(self.layernorm2(self.to32(out1)))
     ffn_output = self.ffn(out2)
     ffn_output = self.dropout2(ffn_output, training=training)
     out2 = out1 + ffn_output
@@ -123,7 +125,6 @@ def get_config(self):
     return base_config
 
   def call(self, x, training):
-
     seq_len = tf.shape(x)[1]
     mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
 
@@ -132,8 +133,9 @@ def call(self, x, training):
     x += self.pos_encoding[:, :seq_len, :]
 
     x = self.dropout(x, training=training)
+    x = tf.cast(x, tf.float16)
 
     for i in range(self.num_layers):
       x = getattr(self, "layer%i" % i)(x, training, mask)
 
-    return self.layernorm(x)
+    return self.layernorm(tf.cast(x, tf.float32))