Skip to content

Commit

Permalink
seq2seq has funcs with same name and it's not clear. renamed for clar…
Browse files Browse the repository at this point in the history
…ity.
  • Loading branch information
bckenstler committed Aug 16, 2017
1 parent 20e8524 commit 60754a3
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions deeplearning2/seq2seq-translation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"\n",
"[This question on Open Data Stack Exchange](http://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages) pointed me to the open translation site http://tatoeba.org/ which has downloads available at http://tatoeba.org/eng/downloads - and better yet, someone did the extra work of splitting language pairs into individual text files here: http://www.manythings.org/anki/\n",
"\n",
"The English to French pairs are too big to include in the repo, so download to `data/eng-fra.txt` before continuing. The file is a tab separated list of translation pairs:\n",
"The English to French pairs are too big to include in the repo, so download to `data/fra.txt` before continuing. The file is a tab separated list of translation pairs:\n",
"\n",
"```\n",
"I am cold. Je suis froid.\n",
Expand Down Expand Up @@ -168,11 +168,11 @@
},
"outputs": [],
"source": [
"def readLangs(lang1, lang2, reverse=False):\n",
"def readLangs(lang1, lang2, pairs_file, reverse=False):\n",
" print(\"Reading lines...\")\n",
"\n",
" # Read the file and split into lines\n",
" lines = open('data/%s-%s.txt' % (lang1, lang2)).read().strip().split('\\n')\n",
" lines = open('data/%s' % (pairs_file)).read().strip().split('\\n')\n",
" \n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
Expand Down Expand Up @@ -268,8 +268,8 @@
}
],
"source": [
"def prepareData(lang1, lang2, reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
"def prepareData(lang1, lang2, pairs_file, reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(lang1, lang2, pairs_file, reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filterPairs(pairs)\n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
Expand All @@ -282,7 +282,7 @@
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, pairs\n",
"\n",
"input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
"input_lang, output_lang, pairs = prepareData('eng', 'fra', 'fra.txt', True)\n",
"print(random.choice(pairs))"
]
},
Expand Down Expand Up @@ -542,8 +542,7 @@
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true,
"heading_collapsed": true
"editable": true
},
"source": [
"## Training\n",
Expand All @@ -559,8 +558,7 @@
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"hidden": true
"editable": true
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -597,8 +595,7 @@
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"hidden": true
"editable": true
},
"outputs": [],
"source": [
Expand All @@ -621,8 +618,7 @@
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"hidden": true
"editable": true
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -664,9 +660,7 @@
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true,
"heading_collapsed": true,
"hidden": true
"editable": true
},
"source": [
"### Attention"
Expand All @@ -678,15 +672,14 @@
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"hidden": true
"editable": true
},
"outputs": [],
"source": [
"# TODO: Make this change during training\n",
"teacher_forcing_ratio = 0.5\n",
"\n",
"def train(input_variable, target_variable, encoder, decoder, encoder_optimizer, \n",
"def attn_train(input_variable, target_variable, encoder, decoder, encoder_optimizer, \n",
" decoder_optimizer, criterion, max_length=MAX_LENGTH):\n",
" encoder_hidden = encoder.initHidden()\n",
"\n",
Expand Down Expand Up @@ -739,8 +732,7 @@
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true,
"heading_collapsed": true
"editable": true
},
"source": [
"# Plotting results\n",
Expand All @@ -754,8 +746,7 @@
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"hidden": true
"editable": true
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1002,6 +993,13 @@
"You could simply run `plt.matshow(attentions)` to see attention output displayed as a matrix, with the columns being input steps and rows being output steps:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"NOTE: This only works when using the attentional decoder, if you've been following the notebook to this point you are using the standard decoder."
]
},
{
"cell_type": "code",
"execution_count": 20,
Expand Down

0 comments on commit 60754a3

Please sign in to comment.