import os
from langchain.prompts import PromptTemplate
from langchain.llms import CTransformers
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.chains import RetrievalQA
import gradio as gr
# local_llm = "https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF/blob/main/zephyr-7b-beta.Q5_K_S.gguf"
# Load model directly
# from transformers import AutoModel
# local_llm = AutoModel.from_pretrained("TheBloke/zephyr-7B-beta-GGUF")
# Generation settings passed through to ctransformers.
config = {
    "max_new_tokens": 1024,
    "repetition_penalty": 1.1,
    "temperature": 0.1,  # low temperature keeps answers close to the retrieved context
    "top_k": 50,
    "top_p": 0.9,
    "stream": True,
    "threads": os.cpu_count() // 2,  # use half of the available CPU cores
}
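# A minimal standalone sketch of these settings with raw ctransformers
# (hypothetical; the app itself wraps the model via LangChain below):
#   from ctransformers import AutoModelForCausalLM
#   m = AutoModelForCausalLM.from_pretrained(
#       "TheBloke/zephyr-7B-beta-GGUF",
#       model_file="zephyr-7b-beta.Q4_0.gguf",
#       model_type="mistral",
#   )
#   print(m("Zephyr is", max_new_tokens=32))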
# Wrap the GGUF checkpoint in LangChain's CTransformers so it can be handed
# directly to RetrievalQA (a raw ctransformers model is not a LangChain LLM).
llm_init = CTransformers(
    model="TheBloke/zephyr-7B-beta-GGUF",
    model_file="zephyr-7b-beta.Q4_0.gguf",
    model_type="mistral",
    lib="avx2",  # for CPU use
    config=config,
)
prompt_template = """Use the following piece of information to answers the question asked by the user.
Don't try to make up the answer if you don't know the answer, simply say I don't know.
Context: {context}
Question: {question}
Only helpful answer below.
Helpful answer:
"""
model_name = "BAAI/bge-large-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
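# Quick sanity check (hypothetical usage, not needed at runtime):
#   vec = embeddings.embed_query("dinosaur diversity")
#   len(vec)  # bge-large-en produces 1024-dimensional vectors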
prompt = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
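# At query time the chain fills the template, roughly:
#   prompt.format(context="<retrieved passage>", question="<user question>")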
# Load the persisted Chroma index and expose it as a top-1 retriever.
load_vector_store = Chroma(
    persist_directory="stores/dino_cosine", embedding_function=embeddings
)
retriever = load_vector_store.as_retriever(search_kwargs={"k": 1})
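# The store at stores/dino_cosine is assumed to have been built offline with
# the same embeddings, e.g. (hypothetical sketch; `docs` is a list of Documents):
#   Chroma.from_documents(
#       docs,
#       embeddings,
#       persist_directory="stores/dino_cosine",
#       collection_metadata={"hnsw:space": "cosine"},
#   )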
# query = "How many genera of dinosaurs currently known?"
# semantic_search = retriever.get_relevant_documents(query)
# chain_type_kwargs = {"prompt": prompt}
# qa = RetrievalQA.from_chain_type(
# llm=llm_init,
# chain_type="stuff",
# retriever=retriever,
# verbose=True,
# chain_type_kwargs=chain_type_kwargs,
# return_source_documents=True,
# )
sample_query = [
    "How many genera of dinosaurs are currently known?",
    "What methods are used to account for the incompleteness of the fossil record?",
    "Were dinosaurs in decline before the Cretaceous-Tertiary boundary?",
]
def get_response(query):
    # "stuff" chain: the retrieved document is stuffed into the prompt's context slot.
    qa = RetrievalQA.from_chain_type(
        llm=llm_init,
        chain_type="stuff",
        retriever=retriever,
        verbose=True,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    response = qa(query)
    # The chain returns a dict; surface only the answer text in the UI.
    return response["result"]
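# Example invocation (assumes the Chroma store and model weights are available):
#   print(get_response("How many genera of dinosaurs are currently known?"))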
query_box = gr.Text(
label="Query",
show_label=True,
max_lines=2,
container=False,
placeholder="Enter your question",
)
gIface = gr.Interface(
fn=get_response,
    inputs=query_box,
outputs="text",
title="Dinosaurs Diversity RAG AI",
description="RAG demo using Zephyr 7B Beta and Langchain",
examples=sample_query,
allow_flagging="never",
)
gIface.launch()