# Import streamlit for app dev
import streamlit as st
# Import transformer classes for generation
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
# Import torch for datatype attributes
import torch
# Import the prompt wrapper...but for llama index
from llama_index.prompts.prompts import SimpleInputPrompt
# Import the llama index HF Wrapper
from llama_index.llms import HuggingFaceLLM
# Bring in embeddings wrapper
from llama_index.embeddings import LangchainEmbedding
# Bring in HF embeddings - need these to represent document chunks
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# Bring in stuff to change service context
from llama_index import set_global_service_context
from llama_index import ServiceContext
# Import deps to load documents
from llama_index import VectorStoreIndex, download_loader
from pathlib import Path
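# Note: these import paths follow the legacy (pre-0.10) llama_index package
# layout; newer releases moved most of this under llama_index.core and
# replaced ServiceContext with Settings.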
# Define variable to hold the llama2 model name
name = "meta-llama/Llama-2-70b-chat-hf"
# Set auth token variable from hugging face
auth_token = "YOUR HUGGING FACE AUTH TOKEN HERE"
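# Alternatively, the token could be read from the environment or from
# Streamlit secrets instead of being hardcoded, e.g.
#   auth_token = os.environ["HF_TOKEN"]    # needs `import os`; "HF_TOKEN" is a placeholder name
#   auth_token = st.secrets["HF_TOKEN"]    # needs an entry in .streamlit/secrets.toml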
@st.cache_resource
def get_tokenizer_model():
    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/', use_auth_token=auth_token)
    # Create model
    model = AutoModelForCausalLM.from_pretrained(name, cache_dir='./model/',
                                                 use_auth_token=auth_token, torch_dtype=torch.float16,
                                                 rope_scaling={"type": "dynamic", "factor": 2}, load_in_8bit=True)
    return tokenizer, model
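# Note: load_in_8bit=True requires the bitsandbytes package and a CUDA GPU;
# even quantized to 8-bit, the 70B checkpoint needs roughly 70 GB of VRAM.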
tokenizer, model = get_tokenizer_model()
# Create a system prompt
system_prompt = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as
helpfully as possible, while being safe. Your answers should not include
any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain
why instead of answering something not correct. If you don't know the answer
to a question, please don't share false information.
Your goal is to provide answers relating to the financial performance of
the company.<</SYS>>
"""
# Throw together the query wrapper
query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")
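# Combined with the system prompt above, every query gets assembled into the
# Llama-2 chat format: <s>[INST] <<SYS>> ... <</SYS>> {query_str} [/INST]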
# Create a HF LLM using the llama index wrapper
llm = HuggingFaceLLM(context_window=4096,
                     max_new_tokens=256,
                     system_prompt=system_prompt,
                     query_wrapper_prompt=query_wrapper_prompt,
                     model=model,
                     tokenizer=tokenizer)
# Create and download embeddings instance
embeddings = LangchainEmbedding(
    HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
)
# Create new service context instance
service_context = ServiceContext.from_defaults(
    chunk_size=1024,
    llm=llm,
    embed_model=embeddings
)
# And set the service context
set_global_service_context(service_context)
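# From here on, indexes and query engines created without an explicit service
# context pick up this one (its LLM, embeddings, and chunk size) by default.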
# Download PDF Loader
PyMuPDFReader = download_loader("PyMuPDFReader")
# Create PDF Loader
loader = PyMuPDFReader()
# Load documents
documents = loader.load(file_path=Path('./data/annualreport.pdf'), metadata=True)
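# metadata=True should attach document metadata (such as the source file path)
# to each loaded document; this is what surfaces in the source view below.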
# Create an index - we'll be able to query this in a sec
index = VectorStoreIndex.from_documents(documents)
# Setup index query engine using LLM
query_engine = index.as_query_engine()
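# Under the hood, a query embeds the question, retrieves the most similar
# chunks from the index, and has the LLM synthesize an answer from them.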
# Create centered main title
st.title('🦙 Llama Banker')
# Create a text input box for the user
prompt = st.text_input('Input your prompt here')
# If the user hits enter
if prompt:
    response = query_engine.query(prompt)
    # ...and write it out to the screen
    st.write(response)
    # Display raw response object
    with st.expander('Response Object'):
        st.write(response)
    # Display source text
    with st.expander('Source Text'):
        st.write(response.get_formatted_sources())
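# To launch the app (assuming this file is saved as app.py and the report
# lives at ./data/annualreport.pdf):
#   streamlit run app.py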