Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ You can then run the `interact.py` script on the pretrained model:
python3 interact.py --model models/
```

After running `interact.py` you can easily exit/stop talking with bot by typing:

```bash
quit
```

## Pretrained model

We make a pretrained and fine-tuned model available on our S3 [here](https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/finetuned_chatbot_gpt.tar.gz). The easiest way to download and use this model is just to run the `interact.py` script to talk with the model. Without any argument, this script will automatically download and cache our model.
Expand Down
12 changes: 8 additions & 4 deletions interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from pprint import pformat
import warnings

from termcolor import colored
import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -136,18 +136,22 @@ def run():
logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))

history = []
print("\n===============================================")
print("================== Conv AI ====================")
print("===============================================\n")
while True:
raw_text = input(">>> ")
raw_text = input(colored("You: ",'green'))
if raw_text == 'quit': break
while not raw_text:
print('Prompt should not be empty!')
raw_text = input(">>> ")
raw_text = input(colored("You: ",'green'))
history.append(tokenizer.encode(raw_text))
with torch.no_grad():
out_ids = sample_sequence(personality, history, tokenizer, model, args)
history.append(out_ids)
history = history[-(2*args.max_history+1):]
out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
print(out_text)
print(colored("Bot:",'red'),out_text)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ transformers==2.5.1
tensorboardX==1.8
tensorflow # for tensorboardX
spacy
termcolor