Oysiyl committed
Commit a716434
1 Parent(s): d0fb496

Initial commit

Files changed (2):
  1. app.py +198 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ from typing import Optional
+ from threading import Thread
+
+ import torch
+ import gradio as gr
+ from langchain.llms.base import LLM
+ from langchain.prompts import PromptTemplate
+ from langchain_community.vectorstores import Pinecone
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ def initialize_model_and_tokenizer(model_name="mistralai/Mistral-7B-Instruct-v0.2"):
+     # Load the model with 4-bit NF4 quantization so it fits on a single GPU.
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_use_double_quant=True,
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         torch_dtype=torch.float16,
+         device_map="auto",
+         quantization_config=quantization_config,
+     )
+     model.eval()
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tokenizer.pad_token = tokenizer.eos_token
+     return model, tokenizer
+
+
+ def init_chain(model, tokenizer, db, embed, temp, max_new_tokens, top_p, top_k, r_penalty):
+     class CustomLLM(LLM):
+         """LLM wrapper that streams generated tokens through a TextIteratorStreamer."""
+
+         streamer: Optional[TextIteratorStreamer] = None
+
+         def _call(self, prompt, stop=None, run_manager=None) -> str:
+             self.streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
+             inputs = tokenizer(prompt, return_tensors="pt")
+             input_ids = inputs["input_ids"].to("cuda")
+             generate_kwargs = dict(
+                 temperature=float(temp),
+                 max_new_tokens=int(max_new_tokens),
+                 top_p=float(top_p),
+                 top_k=int(top_k),
+                 repetition_penalty=float(r_penalty),
+                 do_sample=True,
+             )
+             kwargs = dict(input_ids=input_ids, streamer=self.streamer, **generate_kwargs)
+             # Generate in a background thread so the streamer can be consumed as tokens arrive.
+             thread = Thread(target=model.generate, kwargs=kwargs)
+             thread.start()
+             return ""
+
+         @property
+         def _llm_type(self) -> str:
+             return "custom"
+
+     llm = CustomLLM()
+     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+     questionprompt = PromptTemplate.from_template(
+         """<s>[INST]
+ Use the following pieces of context to answer the question at the end.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ CONTEXT: {context}
+ CHAT HISTORY: {chat_history}
+ QUESTION: {question}
+ Helpful Answer:
+ [/INST]
+ """
+     )
+     # Apply the custom prompt when stuffing the retrieved context into the final answer.
+     llm_chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=db.as_retriever(search_kwargs={"k": 5}),
+         memory=memory,
+         combine_docs_chain_kwargs={"prompt": questionprompt},
+     )
+
+     return llm_chain, llm
+
+
+ index_name = "resume-demo"
+
+ queries = [
+     ["Which master's degree does Dmytro Kisil have?"],
+     ["What salary is Dmytro Kisil looking for?"],
+     ["How long has Dmytro Kisil been looking for a job?"],
+     ["Why did Dmytro Kisil move to the Netherlands?"],
+     ["When did Dmytro Kisil leave Ukraine?"],
+     ["Where does Dmytro Kisil live now?"],
+     ["How many years of working experience does Dmytro Kisil have in total?"],
+     ["How soon can Dmytro Kisil start working for my company?"],
+ ]
+
+ embed = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")
+
+ db = Pinecone.from_existing_index(index_name, embed)
+
+ model, tokenizer = initialize_model_and_tokenizer(model_name="mistralai/Mistral-7B-Instruct-v0.2")
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         chatbot = gr.Chatbot()
+         with gr.Row():
+             msg = gr.Textbox(scale=9)
+             submit_b = gr.Button("Submit", scale=1)
+         with gr.Row():
+             retry_b = gr.Button("Retry")
+             undo_b = gr.Button("Undo")
+             clear_b = gr.Button("Clear")
+         examples = gr.Examples(queries, msg)
+         with gr.Accordion("Additional options", open=False):
+             temp = gr.Slider(
+                 label="Temperature",
+                 value=0.01,
+                 minimum=0.01,
+                 maximum=1.00,
+                 step=0.01,
+                 interactive=True,
+                 info="Higher values produce more diverse outputs",
+             )
+             max_new_tokens = gr.Slider(
+                 label="Max new tokens",
+                 value=1024,
+                 minimum=64,
+                 maximum=8192,
+                 step=64,
+                 interactive=True,
+                 info="The maximum number of new tokens to generate",
+             )
+             top_p = gr.Slider(
+                 label="Top-p (nucleus sampling)",
+                 value=0.95,
+                 minimum=0.00,
+                 maximum=1.00,
+                 step=0.01,
+                 interactive=True,
+                 info="Higher values sample more low-probability tokens",
+             )
+             top_k = gr.Slider(
+                 label="Top-k",
+                 value=40,
+                 minimum=0,
+                 maximum=100,
+                 step=1,
+                 interactive=True,
+                 info="Sample from the top-k most likely tokens (0 disables top-k and relies on top-p)",
+             )
+             r_penalty = gr.Slider(
+                 label="Repetition penalty",
+                 value=1.15,
+                 minimum=1.0,
+                 maximum=2.0,
+                 step=0.01,
+                 interactive=True,
+                 info="Penalize repeated tokens",
+             )
+
+     def user(user_message, history):
+         return "", history + [[user_message, None]]
+
+     def undo(history):
+         return history[:-1].copy()
+
+     def retry(user_message, history):
+         try:
+             prev_user_message = history[-1][0]
+         except IndexError:
+             prev_user_message = ""
+         return prev_user_message, history + [[prev_user_message, None]]
+
+     def bot(history, temp, max_new_tokens, top_p, top_k, r_penalty):
+         llm_chain, llm = init_chain(model, tokenizer, db, embed, temp, max_new_tokens, top_p, top_k, r_penalty)
+         # The chain call returns at once; the reply is read token by token from the streamer.
+         llm_chain.run(question=history[-1][0])
+         history[-1][1] = ""
+         for character in llm.streamer:
+             history[-1][1] += character
+             yield history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
+     submit_b.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
+     retry_b.click(retry, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
+     clear_b.click(lambda: None, None, chatbot, queue=False)
+     undo_b.click(undo, chatbot, chatbot, queue=False)
+
+ demo.queue()
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ langchain
+ langchain-community
+ transformers
+ torch
+ accelerate
+ bitsandbytes
+ sentence-transformers
+ pinecone-client