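"""Gradio demo: a resume Q&A chatbot that answers questions about Dmytro Kisil.

Retrieval-augmented generation over a Pinecone index ("resume-demo") using
BAAI/bge-small-en-v1.5 embeddings and a 4-bit quantized Mistral-7B-Instruct-v0.2
model, wired together with LangChain's ConversationalRetrievalChain and streamed
token-by-token into a Gradio Chatbot UI.
"""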
import os
from typing import Optional
from threading import Thread
import torch
import gradio as gr
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Pinecone
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, pipeline

os.environ["TOKENIZERS_PARALLELISM"] = "false"
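
# Load the instruction-tuned model in 4-bit NF4 quantization (bitsandbytes) with a
# float16 compute dtype and double quantization, so it fits on a single GPU.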
def initialize_model_and_tokenizer(model_name="mistralai/Mistral-7B-Instruct-v0.2"):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map='auto',
        quantization_config=quantization_config
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
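
# init_chain builds a fresh ConversationalRetrievalChain for every request so the
# sampling parameters from the UI sliders are picked up. CustomLLM._call does not
# block on generation: it launches model.generate in a background thread and
# returns immediately, while the UI callback reads tokens from the
# TextIteratorStreamer as they are produced.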
def init_chain(model, tokenizer, db, embed, temp, max_new_tokens, top_p, top_k, r_penalty):
    class CustomLLM(LLM):
        """LangChain LLM wrapper that streams tokens through a TextIteratorStreamer."""
        streamer: Optional[TextIteratorStreamer] = None

        def _call(self, prompt, stop=None, run_manager=None) -> str:
            self.streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].to('cuda')
            generate_kwargs = dict(
                temperature=float(temp),
                max_new_tokens=int(max_new_tokens),
                top_p=float(top_p),
                top_k=int(top_k),
                repetition_penalty=float(r_penalty),
                do_sample=True
            )
            kwargs = dict(input_ids=input_ids, streamer=self.streamer, **generate_kwargs)
            thread = Thread(target=model.generate, kwargs=kwargs)
            thread.start()
            return ""

        @property
        def _llm_type(self) -> str:
            return "custom"
    llm = CustomLLM()
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    questionprompt = PromptTemplate.from_template(
        """<s>[INST]
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
Helpful Answer:
[/INST]
"""
    )
    # This prompt fills {context} with the retrieved resume chunks, so it is used
    # for the answer-generation (combine-docs) step of the chain.
    llm_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db.as_retriever(search_kwargs={"k": 5}),
        memory=memory,
        combine_docs_chain_kwargs={"prompt": questionprompt},
    )
    return llm_chain, llm
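
# Module-level setup. The Pinecone index is assumed to exist already and to be
# populated with resume chunks in a separate indexing step (e.g. something like
# Pinecone.from_documents(docs, embed, index_name=index_name); that call is not
# part of this app and is shown only as an illustration).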
index_name = "resume-demo"
queries = [["Which master's degree does Dmytro Kisil have?"],
           ["What salary is Dmytro Kisil looking for?"],
           ["How long has Dmytro Kisil been looking for a job?"],
           ["Why did Dmytro Kisil move to the Netherlands?"],
           ["When did Dmytro Kisil leave Ukraine?"],
           ["Where does Dmytro Kisil live now?"],
           ["How many years of working experience does Dmytro Kisil have in total?"],
           ["How soon can Dmytro Kisil start working for my company?"]]
embed = HuggingFaceBgeEmbeddings(model_name='BAAI/bge-small-en-v1.5')
db = Pinecone.from_existing_index(index_name, embed)
model, tokenizer = initialize_model_and_tokenizer(model_name="mistralai/Mistral-7B-Instruct-v0.2")
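
# Gradio Blocks UI: a Chatbot with a message box, Submit/Retry/Undo/Clear buttons,
# clickable example questions, and an accordion of sampling parameters that are
# forwarded to init_chain on every request.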
with gr.Blocks() as demo:
    with gr.Column():
        chatbot = gr.Chatbot()
        with gr.Row():
            msg = gr.Textbox(scale=9)
            submit_b = gr.Button("Submit", scale=1)
        with gr.Row():
            retry_b = gr.Button("Retry")
            undo_b = gr.Button("Undo")
            clear_b = gr.Button("Clear")
        examples = gr.Examples(queries, msg)
    with gr.Accordion("Additional options", open=False):
        temp = gr.Slider(
            label="Temperature",
            value=0.01,
            minimum=0.01,
            maximum=1.00,
            step=0.01,
            interactive=True,
            info="Higher values produce more diverse outputs",
        )
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            value=1024,
            minimum=64,
            maximum=8192,
            step=64,
            interactive=True,
            info="The maximum number of new tokens to generate",
        )
        top_p = gr.Slider(
            label="Top-p (nucleus sampling)",
            value=0.95,
            minimum=0.00,
            maximum=1.00,
            step=0.01,
            interactive=True,
            info="Higher values sample more low-probability tokens",
        )
        top_k = gr.Slider(
            label="Top-k",
            value=40,
            minimum=0,
            maximum=100,
            step=1,
            interactive=True,
            info="Sample only from the top-k most likely tokens (0 disables top-k and relies on top-p)",
        )
        r_penalty = gr.Slider(
            label="Repetition penalty",
            value=1.15,
            minimum=1.0,
            maximum=2.0,
            step=0.01,
            interactive=True,
            info="Penalize repeated tokens",
        )
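
    # Chat callbacks: `user` appends the new message to the history, `undo` removes
    # the last exchange, and `retry` re-submits the previous question.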
    def user(user_message, history):
        return "", history + [[user_message, None]]

    def undo(history):
        return history[:-1].copy()

    def retry(user_message, history):
        try:
            prev_user_message = history[-1][0]
        except IndexError:
            prev_user_message = ""
        return prev_user_message, history + [[prev_user_message, None]]
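
    # `bot` rebuilds the chain with the current slider values, kicks off generation
    # (CustomLLM._call returns immediately after starting the generate thread), then
    # streams tokens from the TextIteratorStreamer into the last chat message.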
    def bot(history, temp, max_new_tokens, top_p, top_k, r_penalty):
        llm_chain, llm = init_chain(model, tokenizer, db, embed, temp, max_new_tokens, top_p, top_k, r_penalty)
        llm_chain.run(question=history[-1][0])
        history[-1][1] = ""
        for character in llm.streamer:
            history[-1][1] += character
            yield history
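
    # Build an initial chain, then wire the UI events: submitting a message (via the
    # textbox or the Submit button) appends it to the chat and streams the answer;
    # Retry re-asks the last question, Undo drops the last exchange, Clear empties the chat.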
    llm_chain, llm = init_chain(model, tokenizer, db, embed, temp, max_new_tokens, top_p, top_k, r_penalty)
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
    submit_b.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
    retry_b.click(retry, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot, temp, max_new_tokens, top_p, top_k, r_penalty], chatbot)
    clear_b.click(lambda: None, None, chatbot, queue=False)
    undo_b.click(undo, chatbot, chatbot, queue=False)

demo.queue()
demo.launch(share=True)