Add support for training BERT models (with examples)
kpot committed Dec 10, 2018
1 parent 1bbd5b2 commit 7e937ff
Showing 16 changed files with 1,272 additions and 377 deletions.
68 changes: 51 additions & 17 deletions README.md
@@ -1,9 +1,22 @@
Keras-Transformer
=================

Keras-transformer is a library implementing nuts and bolts for
building (Universal) Transformer models using Keras. It allows you
to assemble a multi-step Transformer model in a flexible way, for example:
Keras-transformer is a Python library implementing the nuts and bolts
for building (Universal) Transformer models using [Keras](http://keras.io),
and it comes with [examples](#language-modelling-examples-with-bert-and-gpt)
of how it can be applied.

The library supports:

* positional encoding and embeddings,
* attention masking,
* memory-compressed attention,
* ACT (adaptive computation time),
* a general implementation of [BERT][3] (because the Transformer
is mainly applied to NLP tasks).

It allows you to piece together a multi-step Transformer model
in a flexible way, for example:

```python
transformer_block = TransformerBlock(
@@ -21,11 +34,11 @@ for step in range(transformer_depth):
add_coordinate_embedding(input, step=step))
```

The library supports positional encoding and embeddings,
attention masking, memory-compressed attention, ACT (adaptive computation time).
All pieces of the model (like self-attention, activation function, layer normalization)
are available as Keras layers, so, if necessary, you can build your
version of Transformer, by re-arranging them differently or replacing some of them.

All pieces of the model (like self-attention, the activation function,
and layer normalization) are available as Keras layers, so, if necessary,
you can build your own version of the Transformer by re-arranging them
or replacing some of them.
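
For instance, a minimal sketch of wiring some of those pieces together by
hand could look like the snippet below. This is an illustration only: the
class names, import paths and constructor arguments are assumptions and
should be checked against the actual `keras_transformer` package.

```python
# Hypothetical sketch -- layer names and constructor arguments are assumptions,
# not verified against this commit.
from keras import layers
from keras_transformer.transformer import (
    LayerNormalization, MultiHeadSelfAttention, TransformerTransition)


def custom_transformer_step(x, num_heads=8):
    # self-attention followed by a residual connection and layer normalization
    attended = MultiHeadSelfAttention(num_heads, use_masking=True)(x)
    x = LayerNormalization()(layers.Add()([x, attended]))
    # position-wise feed-forward ("transition") part, same residual pattern
    transitioned = TransformerTransition(activation='relu')(x)
    return LayerNormalization()(layers.Add()([x, transitioned]))
```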

The (Universal) Transformer is a deep learning architecture
described in arguably one of the most impressive DL papers of 2017 and 2018:
@@ -48,16 +61,26 @@ then switch to the cloned directory and run pip
cd keras-transformer
pip install .

Language modelling example
--------------------------
This repository contains a simple [example](./example) showing how Keras-transformer works.
Please note that the project requires Python >= 3.6.

Language modelling examples with BERT and GPT
---------------------------------------------
This repository contains simple [examples](./example) showing how
Keras-transformer works.
It's not a rigorous evaluation of the model's capabilities,
but rather a demonstration of how to use the code.

The code trains a simple language-modeling network on the
The code trains [simple language-modeling networks](./example/models.py) on the
[WikiText-2](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset)
dataset and evaluates its perplexity.
The model itself is an Adaptive Universal Transformer with five layers.
dataset and evaluates their perplexity. The model is either a [vanilla
Transformer][1] or an [Adaptive Universal Transformer][2] (the default)
with five layers, and either variant can be trained using:

* [Generative pre-training][4] (GPT), which involves using masked self-attention
to prevent the model from "looking into the future".
* [BERT][3], which doesn't restrict self-attention, allowing the model
to fill the gaps using both left and right context (a short sketch of both
regimes follows this list).
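
To make the difference concrete, here is a tiny generic NumPy sketch of what
each regime does to the input. It is not code from the example scripts: the
15% masking rate comes from the BERT paper, and `MASK_ID` is a made-up value.

```python
import numpy as np

seq_len = 5
token_ids = np.array([12, 7, 33, 41, 2])

# GPT-style training: a causal (lower-triangular) attention mask, so that
# position i can attend only to positions <= i and never sees the future.
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

# BERT-style training: attention is unrestricted; instead about 15% of the
# input tokens are replaced by a special [MASK] token, and the model must
# reconstruct them from both the left and the right context.
MASK_ID = 3  # hypothetical id of the [MASK] token
rng = np.random.RandomState(0)
masked_positions = rng.rand(seq_len) < 0.15
corrupted_input = np.where(masked_positions, MASK_ID, token_ids)
targets = np.where(masked_positions, token_ids, -1)  # -1 means "no loss here"
```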


To launch the code, you will first need to install the requirements listed
in [example/requirements.txt](./example/requirements.txt). Assuming you work
@@ -71,13 +94,13 @@ Tensorflow and PlaidML as backends):

pip install tensorflow

Now you can launch the example itself as
Now you can launch the GPT example as

python -m example.run --save lm_model.h5
python -m example.run_gpt --save lm_model.h5

to see all command line options and their default values, try

python -m example.run --help
python -m example.run_gpt --help

If all goes well, after launching the example you should see
the perplexity falling with each epoch.
@@ -95,5 +118,16 @@ After 200 epochs (~5 hours) of training on GeForce 1080 Ti, I've got
a validation perplexity of about 51.61 and a test perplexity of 50.82. The score
can be further improved, but that is not the point of this demo.
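
For reference, the perplexity quoted here is simply the exponent of the
average cross-entropy that Keras logs during training (assuming the usual
natural-log cross-entropy), so converting a logged loss looks like this; the
loss value below is a hypothetical illustration, not taken from a real run.

```python
import math

# hypothetical mean validation cross-entropy (in nats) read from the logs
validation_loss = 3.944
validation_perplexity = math.exp(validation_loss)
print(f'validation perplexity: {validation_perplexity:.2f}')  # about 51.6
```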

The BERT model example can be launched similarly

python -m example.run_bert --save lm_model.h5 --model vanilla

but you will need to be patient. BERT easily achieves better performance
than GPT, but requires much more training time to converge.

[1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need"
[2]: https://arxiv.org/abs/1807.03819 "Universal Transformers"
[3]: https://arxiv.org/abs/1810.04805 "BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding"
[4]: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf
"Improving Language Understanding by Generative Pre-Training"
1 change: 0 additions & 1 deletion example/__init__.py
@@ -1 +0,0 @@

23 changes: 16 additions & 7 deletions example/bpe.py
@@ -1,7 +1,7 @@
import collections
from typing import (
Iterable, Tuple, List, NamedTuple, Dict, Sequence, FrozenSet, TextIO,
Callable)
Callable, Optional)


BPE_WORD_TAIL = '</w>'
@@ -112,14 +112,14 @@ class BPETokenizer:

def __init__(self, merges: BPEMerges,
word_tokenizer: Callable[[str], Iterable[str]],
mark_sequence_edges: bool=True):
mark_sequence_edges: bool = True):
self.word_tokenizer = word_tokenizer
self.merges = merges
self.bpe_cache = {}
self.mark_sequence_edges = mark_sequence_edges

def apply(self, text: str,
low_case: bool=True) -> Iterable[str]:
low_case: bool = True) -> Iterable[str]:
if self.mark_sequence_edges:
yield TOKEN_FOR_BEGINNING_OF_SEQUENCE
for token in self.word_tokenizer(text):
@@ -135,13 +135,20 @@ def apply(self, text: str,


class BPEVocabulary:
def __init__(self, bpe_vocabulary_file: TextIO):
def __init__(self, bpe_vocabulary_file: TextIO,
special_tokens: Optional[Sequence[str]] = None):
vocabulary = {
TOKEN_FOR_UNKNOWN: ID_FOR_UNKNOWN_TOKEN,
TOKEN_FOR_BEGINNING_OF_SEQUENCE: ID_FOR_BEGINNING_OF_SEQUENCE,
TOKEN_FOR_END_OF_SEQUENCE: ID_FOR_END_OF_SEQUENCE,
TOKEN_FOR_PADDING: ID_FOR_PADDING}
i = max(vocabulary.values()) + 1
if special_tokens is not None:
for extra_token in special_tokens:
if extra_token not in vocabulary:
vocabulary[extra_token] = i
i += 1
self.first_normal_token_id = i
assert i == len(vocabulary)
for line in bpe_vocabulary_file:
if line:
@@ -151,15 +158,18 @@ def __init__(self, bpe_vocabulary_file: TextIO):
i += 1
self.token_to_id = vocabulary
self.id_to_token = {v: k for k, v in vocabulary.items()}
self.last_normal_token_id = i - 1


class BPEEncoder:
"""
Converts a text into a stream of WordIDs and BPE tokens
"""
def __init__(self, bpe_tokenizer: BPETokenizer, bpe_vocabulary: TextIO):
def __init__(self, bpe_tokenizer: BPETokenizer, bpe_vocabulary: TextIO,
special_tokens: Optional[Sequence[str]] = None):
self.bpe_tokenizer = bpe_tokenizer
self.vocabulary = BPEVocabulary(bpe_vocabulary)
self.vocabulary = BPEVocabulary(
bpe_vocabulary, special_tokens=special_tokens)

def __call__(self, text: str) -> Iterable[Tuple[int, str]]:
token_to_id = self.vocabulary.token_to_id
@@ -176,4 +186,3 @@ def build_vocabulary(tokens: Iterable[str]) -> List[Tuple[str, int]]:
for token in tokens:
vocabulary[token.lower()] += 1
return sorted(vocabulary.items(), key=lambda i: (i[1], i[0]), reverse=True)
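
A short usage sketch of the new `special_tokens` parameter of `BPEVocabulary`
is shown below. It is illustrative only: the vocabulary file path and the
token strings are made-up placeholders, not values used by the example scripts.

```python
from example.bpe import BPEVocabulary

# Register BERT-style control tokens right after the built-in ones, so their
# ids stay stable regardless of the size of the regular BPE vocabulary.
# '<MASK>' and '<SEP>' are hypothetical token strings chosen for illustration.
with open('bpe_vocabulary.txt') as vocab_file:  # placeholder path
    vocab = BPEVocabulary(vocab_file, special_tokens=['<MASK>', '<SEP>'])

print(vocab.token_to_id['<MASK>'])   # special tokens get the lowest free ids
print(vocab.first_normal_token_id)   # id of the first regular BPE token
print(vocab.last_normal_token_id)    # id of the last regular BPE token
```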
