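"""Gradio Space that chats with Zephyr-7B.

Two paths are wired up: a streaming transformers model for the main
ChatInterface, and a RetrievalQA chain over a pre-built Chroma vector
store downloaded from a public S3 bucket (used by the commented-out
"Chat with PDF" Blocks UI further down).
"""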
import os
from threading import Thread
from typing import Iterator

import boto3
import gradio as gr
from botocore import UNSIGNED
from botocore.client import Config
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.llms import CTransformers
from langchain.vectorstores import Chroma
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=350, chunk_overlap=10)
embeddings = HuggingFaceHubEmbeddings()
# Quantized GGUF build of Zephyr-7B, served with CTransformers and used
# by the RetrievalQA chain below.
model_id = "TheBloke/zephyr-7B-beta-GGUF"
# model_id = "HuggingFaceH4/zephyr-7b-beta"
# model_id = "meta-llama/Llama-2-7b-chat-hf"

device = "cpu"
llm_model = CTransformers(
    model=model_id,
    model_type="mistral",
    max_new_tokens=4384,
    temperature=0.2,
    repetition_penalty=1.13,
    device=device,  # set the device explicitly during model initialization
)

# Full-precision Zephyr-7B used by the streaming chat interface.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
print("initialized model")
tokenizer.use_default_system_prompt = False
# Pull the pre-built Chroma store from a public (unsigned) S3 bucket.
# boto3 does not create the target directory, so make it first.
os.makedirs("./chroma_db", exist_ok=True)
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
s3.download_file("rad-rag-demos", "vectorstores/chroma.sqlite3", "./chroma_db/chroma.sqlite3")
db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
db.get()  # sanity check that the store opens and returns its records
retriever = db.as_retriever()

qa = RetrievalQA.from_chain_type(
    llm=llm_model, chain_type="stuff", retriever=retriever, return_source_documents=True
)
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
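    """Stream a chat completion from the transformers model, token by token."""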
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
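    # Generation runs on a background thread; TextIteratorStreamer yields
    # decoded tokens as they are produced, so the UI can stream partial output.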
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
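
# Helpers for the alternative Blocks UI (commented out below): they route the
# user's question through the RetrievalQA chain instead of raw generation.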
def add_text(history, text):
history = history + [(text, None)]
return history, ""
def bot(history):
response = infer(history[-1][0])
history[-1][1] = response['result']
return history
def infer(question):
    """Run a question through the RetrievalQA chain."""
    return qa({"query": question})
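
# Styling and header markup for the alternative Blocks UI below.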
css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""
title = """
<div style="text-align: center;max-width: 700px;">
<h1>Chat with PDF</h1>
<p style="text-align: center;">Upload a .PDF from your computer, click the "Load PDF to LangChain" button, <br />
and once everything is ready, you can start asking questions about the PDF ;)</p>
</div>
"""
# with gr.Blocks(css=css) as demo:
# with gr.Column(elem_id="col-container"):
# gr.HTML(title)
# chatbot = gr.Chatbot([], elem_id="chatbot")
# with gr.Row():
# question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
# question.submit(add_text, [chatbot, question], [chatbot, question]).then(
# bot, chatbot, chatbot
# )
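
# Streaming chat UI backed by generate(); note that this path talks to the
# transformers model directly and does not use the retriever.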
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(label="System prompt", lines=6),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=None,
examples=[
["Hello there! How are you doing?"],
["Can you explain briefly to me what is the Python programming language?"],
["Explain the plot of Cinderella in a sentence."],
["How many hours does it take a man to eat a Helicopter?"],
["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
],
)
with gr.Blocks(css="style.css") as demo:
# gr.Markdown(DESCRIPTION)
# gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
chat_interface.render()
# gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.launch()