# Hugging Face Space page header at capture time (status: Sleeping)
# ---------------------------------------------------------------------------
# 4-bit model loading (bitsandbytes)
# ---------------------------------------------------------------------------
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Module-level default checkpoint name. Other sharded checkpoints that have
# been used with this script:
#   - "Trelis/Llama-2-7b-chat-hf-sharded-bf16"
#   - "bn22/Mistral-7B-Instruct-v0.1-sharded"
#   - "HuggingFaceH4/zephyr-7b-beta"
model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"
def load_quantized_model(model_name: str):
    """
    Load a causal language model with 4-bit (NF4) bitsandbytes quantization.

    :param model_name: Name or path of the model to be loaded.
    :return: The quantized ``AutoModelForCausalLM`` (a plain ``transformers``
        model, not a LangChain LLM).
    """
    # Imported here so the loader does not depend on the module-level
    # ``import torch`` that appears later in this file.
    import torch

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # Request 4-bit loading only through ``quantization_config``. Recent
    # ``transformers`` releases raise ValueError when ``load_in_4bit`` is
    # also passed as a separate keyword alongside ``quantization_config``.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
# ---------------------------------------------------------------------------
# Vector-store chat (retrieval-augmented generation over a web page)
# ---------------------------------------------------------------------------
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Imported from ``langchain_community`` directly: the ``langchain.embeddings``
# and ``langchain.vectorstores`` paths are deprecated re-export shims.
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma, FAISS
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

# Load variables from a local .env file into the process environment.
load_dotenv()
def get_vectorstore_from_url(url):
    """
    Fetch a web page, split it into chunks, embed them and store them in Chroma.

    Chunks are embedded on CPU with the ``BAAI/bge-base-en-v1.5`` BGE model,
    using normalized embeddings so that similarity is cosine similarity. They
    are written to the Chroma collection persisted under ``./chroma_db``.

    :param url: Address of the web page to index.
    :return: The ``Chroma`` vector store containing the page's chunks.
    """
    import hashlib

    # Load the page as documents and split them with the splitter's defaults.
    documents = WebBaseLoader(url).load()
    document_chunks = RecursiveCharacterTextSplitter().split_documents(documents)

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={"device": "cpu"},
        # Unit-length vectors, so dot-product ranking equals cosine similarity.
        encode_kwargs={"normalize_embeddings": True},
    )
    persist_directory = "./chroma_db"

    if not document_chunks:
        # Nothing to add (e.g. an empty page): just open the persisted store
        # instead of sending an empty batch to Chroma.
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Deterministic ids (URL + chunk position + content) so that re-indexing
    # the same page writes the same ids again rather than appending a fresh
    # set of duplicates to the persisted collection on every call.
    # NOTE(review): chunks from earlier versions of a changed page are not
    # removed; clear ./chroma_db if that matters.
    ids = [
        hashlib.sha256(
            f"{url}\x00{index}\x00{chunk.page_content}".encode("utf-8")
        ).hexdigest()
        for index, chunk in enumerate(document_chunks)
    ]
    return Chroma.from_documents(
        document_chunks,
        embeddings,
        ids=ids,
        persist_directory=persist_directory,
    )
import functools


@functools.lru_cache(maxsize=4)
def _load_langchain_llm(model_name, max_new_tokens=512):
    """
    Return ``model_name`` loaded in 4-bit and wrapped as a LangChain LLM.

    ``load_quantized_model`` returns a plain ``transformers`` model, which is
    not a LangChain language model and so cannot be used directly as the
    ``llm`` of the chains below. Wrapping it in a ``text-generation``
    pipeline and ``HuggingFacePipeline`` makes it usable. The result is
    cached per (model_name, max_new_tokens), so each checkpoint is loaded once
    per process instead of on every request.

    :param model_name: Name or path of the checkpoint to load.
    :param max_new_tokens: Upper bound on tokens generated per call.
    :return: A ``HuggingFacePipeline`` LLM.
    """
    from langchain_community.llms import HuggingFacePipeline

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generation = pipeline(
        "text-generation",
        model=load_quantized_model(model_name),
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        # Return only the generated continuation, not the echoed prompt.
        return_full_text=False,
    )
    return HuggingFacePipeline(pipeline=text_generation)


def get_context_retriever_chain(vector_store):
    """
    Build a retriever that first turns the conversation into a search query.

    :param vector_store: Store to search; must provide ``as_retriever()``.
    :return: The chain built by ``create_history_aware_retriever``.
    """
    # Query-rewriting model. Other sharded checkpoints that have been tried:
    #   "bn22/Mistral-7B-Instruct-v0.1-sharded",
    #   "Trelis/Llama-2-7b-chat-hf-sharded-bf16", "HuggingFaceH4/zephyr-7b-beta"
    # NOTE(review): this differs from the module-level ``model_name`` used by
    # get_conversational_rag_chain, so two 7B models get loaded; confirm intent.
    model_name = "anakin87/zephyr-7b-alpha-sharded"
    llm = _load_langchain_llm(model_name)
    retriever = vector_store.as_retriever()
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
    ])
    return create_history_aware_retriever(llm, retriever, prompt)


def get_conversational_rag_chain(retriever_chain):
    """
    Build the answering chain: retrieve documents, then answer from them.

    :param retriever_chain: Retriever (e.g. from get_context_retriever_chain)
        whose documents fill the ``{context}`` slot of the prompt.
    :return: The chain built by ``create_retrieval_chain``; its output dict
        carries the reply under ``"answer"``.
    """
    # ``model_name`` here is the module-level checkpoint name.
    llm = _load_langchain_llm(model_name)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
def get_response(user_input):
    """
    Answer ``user_input`` using the vector store and history in Streamlit state.

    Reads ``st.session_state.vector_store`` and ``st.session_state.chat_history``.

    NOTE(review): ``st`` (Streamlit) is not imported anywhere in this file, and
    a later ``get_response`` definition for the Gradio app rebinds this name,
    so this variant is never reachable as written.

    :param user_input: The user's question.
    :return: The ``"answer"`` produced by the conversational RAG chain.
    """
    retriever_chain = get_context_retriever_chain(st.session_state.vector_store)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
    response = conversation_rag_chain.invoke({
        "chat_history": st.session_state.chat_history,
        "input": user_input,
    })
    return response['answer']
# ---------------------------------------------------------------------------
# Gradio front end
# ---------------------------------------------------------------------------
import gradio as gr

# Module-level conversation-history placeholder (starts empty).
chat_history = list()
_DEFAULT_URL = "https://www.bofrost.de/shop/laenderkueche_5573/italienische-kueche_5576/linguine-mit-feinen-pilzen.html?position=1&clicked="

# Vector stores already built in this process, keyed by URL.
_vector_store_cache = {}


def get_response(user_input, url=_DEFAULT_URL):
    """
    Answer ``user_input`` with retrieval-augmented generation over a web page.

    The page at ``url`` is indexed once per process and then reused, rather
    than being fetched and embedded again for every message.

    :param user_input: The user's question (the Gradio text box value).
    :param url: Page to chat with; defaults to a bofrost product page.
    :return: The ``"answer"`` string produced by the RAG chain.
    """
    vector_store = _vector_store_cache.get(url)
    if vector_store is None:
        vector_store = get_vectorstore_from_url(url)
        _vector_store_cache[url] = vector_store

    retriever_chain = get_context_retriever_chain(vector_store)
    conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
    # Every request starts with an empty history on purpose: the UI is a
    # single text-in/text-out form, and a module-level history list would be
    # shared by every user of the process.
    response = conversation_rag_chain.invoke({
        "chat_history": [],
        "input": user_input,
    })
    return response['answer']
def simple(text: str):
    """Echo ``text`` with a fixed suffix; a model-free stand-in handler."""
    suffix = " hhhmmm "
    return text + suffix
# Single text-in / text-out web form. Pass fn=simple instead to exercise the
# UI without loading any models.
app = gr.Interface(
    title="Chat with Websites",
    description="Type your message and chat with websites.",
    inputs=["text"],
    outputs="text",
    fn=get_response,
)

# Example questions for the default (German-language) bofrost page:
#   "wie registriere ich mich bei bofrost?"  (How do I register with bofrost?)
#   "Was kosten Linguine"                    (How much do the linguine cost?)
app.launch(share=True, debug=True)