This repository was archived by the owner on May 1, 2025. It is now read-only.

Lower memory #69

Open · wants to merge 5 commits into base: master

88 changes: 50 additions & 38 deletions generation.py
@@ -13,14 +13,15 @@
from tensorflow.python import debug as tf_debug
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python import pywrap_tensorflow
import fastBPE
import platform

use_py3 = platform.python_version()[0] == '3'

parser = argparse.ArgumentParser(description='TensorFlow code for generating from CTRL')
parser.add_argument('--model_dir', type=str, required=True,
help='location of model checkpoint')
parser.add_argument('--model_path', type=str, required=True,
help='location of model *data* checkpoint; this is NOT the directory but rather the model checkpoint')
parser.add_argument('--seed', type=int, default=1337,
help='random seed for TensorFlow, numpy and PythonHash')
parser.add_argument('--generate_num', type=int, default=256,
@@ -33,6 +34,10 @@
help='topk value for sampling from the softmax distribution ; 0 means no topk preferred')
parser.add_argument('--penalty', type=float, default=1.2,
help='repetition penalty for greedy sampling')
parser.add_argument('--print_once', action='store_true',
help='the completion is printed only at the end; not every word')
parser.add_argument('--topn', type=int, default=0,
help='print top-n candidates during generations; defaults to 0 which is no printing')

args = parser.parse_args()
tf.random.set_random_seed(args.seed)
@@ -75,10 +80,10 @@ class TiedEmbeddingSoftmax(tf.keras.layers.Layer):

def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
super(TiedEmbeddingSoftmax, self).__init__()
self.w = self.add_weight(name='w', shape=(vocab_size, embedding_size),
self.w = self.add_weight(name='w', shape=(vocab_size, embedding_size), dtype=tf.float32,
initializer='random_normal',
trainable=True)
self.b = self.add_weight(name='b', shape=(vocab_size,),
self.b = self.add_weight(name='b', shape=(vocab_size,), dtype=tf.float32,
initializer='zeros',
trainable=True)
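
The embedding and softmax weights above are pinned to float32 while the rest of the model runs in float16. This matters because TensorFlow does not silently promote mixed dtypes: a matmul between a float16 activation and a float32 weight raises a dtype error, so every boundary between the two precisions needs an explicit cast (which is why the encoder's final output is cast back to float32 in transformer.py below). A minimal illustration, not taken from the PR:

    import tensorflow as tf

    x16 = tf.ones([1, 4], dtype=tf.float16)  # stand-in for an fp16 activation
    w32 = tf.ones([4, 8], dtype=tf.float32)  # stand-in for an fp32 weight (e.g. the embedding)

    # tf.matmul(x16, w32) would raise a dtype error; the cast has to be explicit.
    y = tf.matmul(tf.cast(x16, tf.float32), w32)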

@@ -132,26 +137,6 @@ def loss(labels, logits):
print(model.summary())


# IMPORTANT
# this is where the saved model is presented to the code
# the model directory should have the model checkpoint and
# a checkpoint file
run_config = tf.contrib.tpu.RunConfig(
model_dir=args.model_dir)


# this converts the Keras model to a TensorFlow estimator
# this step is critical
# remember to patch the TF 1.14 file before running the code, else you're going to see errors here
estimator_model = tf.keras.estimator.model_to_estimator(keras_model=model, config=run_config)

# we now create a serving function from this estimator
# this enables us to load the model once and easily query it multiple times
def serving_input_fn():
inputs = {'input_1': tf.placeholder(tf.int32, [1,seq_length])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
predict_fn = tf.contrib.predictor.from_estimator(estimator_model, serving_input_fn)

# almost there, we now take the user prompt and tokenize with BPE
# load BPE codes
bpe = fastBPE.fastBPE('codes', 'vocab')
@@ -161,7 +146,29 @@ def serving_input_fn():
penalty = args.penalty
topk = args.topk

# Load the model file
chkpt_for_reader = '.'.join(args.model_path.split('.')[:-1])
reader = pywrap_tensorflow.NewCheckpointReader(chkpt_for_reader)

# assign weights from the checkpoint to the Keras model
# this is super hacky but I couldn't find a better way to do this
# PR is highly welcome if you know of a better way

# embedding and softmax
# these are fp32
model.layers[1].trainable_variables[0].assign(tf.cast(reader.get_tensor('w'), tf.float32))
model.layers[1].trainable_variables[1].assign(tf.cast(reader.get_tensor('b'), tf.float32))

# encoder weights
for _ in range(len(model.layers[2].trainable_weights)):
tensor = model.layers[2].trainable_weights[_]
if 'normalization' in tensor.name[:-2]: # layernorm is fp32
tensor.assign(tf.cast(reader.get_tensor(tensor.name[:-2]), tf.float32))
else: # everything else is fp16
tensor.assign(tf.cast(reader.get_tensor(tensor.name[:-2]), tf.float16))

while True:
print('WARNING! THIS VERSION OF THE CODE ALLOWS FOR LOWER MEMORY USAGE THROUGH FP16 QUANTIZATION BUT IS UNTESTED; GENERATIONS MAY BE WORSE. USE AT YOUR OWN RISK. ')
prompt = raw_input('ENTER PROMPT: ') if not use_py3 else input('ENTER PROMPT: ')

# tokenize provided prompt
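
The block above restores weights by name through the low-level checkpoint reader instead of going through an Estimator, casting the embedding/softmax and layer-norm parameters to float32 and the remaining encoder weights to float16. A quick way to sanity-check which variables the checkpoint actually exposes before assigning them (a sketch, not from the PR; the prefix is a placeholder for whatever chkpt_for_reader computes above):

    from tensorflow.python import pywrap_tensorflow

    reader = pywrap_tensorflow.NewCheckpointReader('model.ckpt-413000')  # placeholder prefix
    # list every variable stored in the checkpoint along with its shape and dtype
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        print(name, shape, reader.get_tensor(name).dtype)
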
@@ -178,13 +185,13 @@ def serving_input_fn():
# this is done by sliding the window over (past 512 tokens) and continuing prediction
# I'm sure this can be simplified (TODO)
if token <= seq_length:
prompt_logits = predict_fn({'input_1':tokens_generated[:, :seq_length]})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
prompt_logits = model.predict_on_batch(tokens_generated[:, :seq_length]).squeeze() / (temperature if temperature>0 else 1.)
_token = token if token < seq_length else -1
else:
_token = -1
end = token + 1
start = token - seq_length + 2
prompt_logits = predict_fn({'input_1':np.hstack((tokens_generated[:,0:1], tokens_generated[:,start:end]))})['tied_embedding_softmax'].squeeze() / (temperature if temperature>0 else 1.)
prompt_logits = model.predict_on_batch(np.hstack((tokens_generated[:,0:1], tokens_generated[:,start:end]))).squeeze() / (temperature if temperature>0 else 1.)


# if penalty (for repetition) is non-zero,
@@ -233,6 +240,8 @@ def serving_input_fn():
# then we will use the whole list
nucleus = len(pruned_list)

pruned_list = pruned_list[:nucleus]

# if you want to disallow more complex tokens, you can do so here
# for instance, if you want to disallow anything with the phrase `http`,
# you can delete them from the pruned_list
@@ -243,23 +252,21 @@ def serving_input_fn():
tokens_to_disallow.append(_)
pruned_list = np.delete(pruned_list, tokens_to_disallow)

if args.topn > 0 :
print('TOPN :: top-n alternatives:', [idx2word[_] for _ in pruned_list[:args.topn]])

# if temperature is 0
# just pick the first (most probable) token
if temperature==0:
idx = pruned_list[0]
else:
# else,
# sample from the pruned_list with the logits
chosen_idx = int(tf.random.categorical(np.expand_dims(prompt_logits[0][_token][pruned_list],0), num_samples=1).numpy())
chosen_idx = int(tf.random.categorical(np.expand_dims(prompt_logits[_token][pruned_list],0), num_samples=1).numpy())
idx = pruned_list[chosen_idx]


# if you want to do some debugging,
# like which one was chosen,
# what the top25 were,
# here is your opportunity.
#print('chosen:', idx2word[idx])
#print('top25 alternatives:', pruned_list[:25])
if args.topn > 0 :
print('TOPN :: chosen word:', idx2word[idx])

# assign the token for generation
tokens_generated[0][token+1] = idx
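
With model.predict_on_batch the squeezed logits are indexed directly as prompt_logits[_token], dropping the extra leading index that was needed with the estimator predictor's output; the sampling call itself is unchanged. A standalone sketch of what that sampling step does (toy values, not from the PR; .numpy() requires eager execution):

    import numpy as np
    import tensorflow as tf

    pruned_list = np.array([17, 42, 7])               # candidate token ids left after pruning
    logits = np.array([[2.0, 1.0, 0.5]], np.float32)  # logits for just those candidates, shape [1, 3]

    # tf.random.categorical samples an index in proportion to softmax(logits)
    chosen_idx = int(tf.random.categorical(logits, num_samples=1).numpy())
    idx = pruned_list[chosen_idx]                      # map back to a vocabulary id
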
@@ -270,10 +277,15 @@ def serving_input_fn():

tokens_generated_so_far = ' '.join([idx2word[c] for c in tokens_generated[0].squeeze()[:token+2]])
tokens_generated_so_far = re.sub('(@@ )', '', string=tokens_generated_so_far)
tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
print(tokens_generated_so_far)
print()

tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
if not args.print_once:
print('---------------------------------------')
print(tokens_generated_so_far)
print()
print('---------------------------------------')
print(tokens_generated_so_far)
print()


except KeyboardInterrupt: #Exception as e:
print('Continuing')
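
Taken together, the generation.py changes drop the Estimator round-trip (RunConfig, model_to_estimator, serving_input_fn, contrib predictor) and query the Keras model directly; predict_on_batch runs one forward pass on exactly the batch it is handed, which is all the generation loop needs. A toy illustration with a stand-in model (shapes are illustrative only):

    import numpy as np
    import tensorflow as tf

    toy = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(3,))])
    batch = np.zeros((1, 3), dtype=np.float32)
    out = toy.predict_on_batch(batch)  # single forward pass, array of shape (1, 5)
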
18 changes: 10 additions & 8 deletions transformer.py
@@ -19,15 +19,15 @@ def positional_encoding(position, d_model_size):

def scaled_dot_product_attention(q, k, v, mask):
# calculate attention
matmul_qk = tf.matmul(q, k, transpose_b=True)
matmul_qk = tf.cast(tf.matmul(q, k, transpose_b=True), tf.float32)

dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

if mask is not None:
scaled_attention_logits += (mask * -1e9)
scaled_attention_logits += (mask * -1e3)

attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
attention_weights = tf.cast(tf.nn.softmax(scaled_attention_logits, axis=-1), tf.float16)
output = tf.matmul(attention_weights, v)
return output
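
The attention math above stays in float32 even though the surrounding activations are float16: the QK^T logits are cast up, scaled, masked and softmaxed in float32, and only the resulting attention weights are cast back to float16 before multiplying with v. The additive mask constant is also reduced from -1e9 to -1e3, presumably to stay comfortably inside float16 range: float16 tops out around 65504, so -1e9 is not even representable in half precision, while a logit gap of 1e3 still zeroes out masked positions after the softmax. A small numeric check, not from the PR:

    import numpy as np

    print(np.finfo(np.float16).max)  # 65504.0 -- largest finite float16
    print(np.float16(-1e9))          # -inf    -- -1e9 overflows half precision
    print(np.exp(np.float32(-1e3)))  # 0.0     -- a -1e3 logit gap still vanishes after softmax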

@@ -85,14 +85,16 @@ def __init__(self, d_model_size, num_heads, dff, rate=0.1):

self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
self.to32 = lambda x: tf.cast(x, tf.float32)
self.to16 = lambda x: tf.cast(x, tf.float16)

def call(self, x, training, mask):
normed = self.layernorm1(x)
normed = self.to16(self.layernorm1(self.to32(x)))
attn_output = self.multi_head_attention(normed, normed, normed, mask)
attn_output = self.dropout1(attn_output, training=training)
out1 = x + attn_output

out2 = self.layernorm2(out1)
out2 = self.to16(self.layernorm2(self.to32(out1)))
ffn_output = self.ffn(out2)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = out1 + ffn_output
@@ -123,7 +125,6 @@ def get_config(self):
return base_config

def call(self, x, training):

seq_len = tf.shape(x)[1]

mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
@@ -132,8 +133,9 @@ def call(self, x, training):
x += self.pos_encoding[:, :seq_len, :]

x = self.dropout(x, training=training)
x = tf.cast(x, tf.float16)

for i in range(self.num_layers):
x = getattr(self, "layer%i" % i)(x, training, mask)
return self.layernorm(x)
return self.layernorm(tf.cast(x, tf.float32))
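
The net effect of the transformer.py changes is that the encoder stack computes in float16 while every layer norm and the final encoder output are kept in float32, matching the per-tensor dtypes assigned from the checkpoint in generation.py. CTRL has on the order of 1.6 billion parameters (an approximate figure), so keeping the bulk of the weights in half precision roughly halves the weight memory. A back-of-the-envelope estimate under that assumption:

    params = 1.6e9                      # approximate parameter count (assumption)
    fp32_gb = params * 4 / 2**30        # 4 bytes per float32 weight
    fp16_gb = params * 2 / 2**30        # 2 bytes per float16 weight
    print(round(fp32_gb, 1), round(fp16_gb, 1))  # roughly 6.0 GB vs 3.0 GB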