# cima-free-chat / query_data.py
# Source: Hugging Face file viewer (raw / history / blame), scanned "No virus", 4.98 kB.
# Last commit: 9775192 "Update query_data.py" by ethanrom.
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
import pickle
import config
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from memory import memory3
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
from typing import Any, Optional, Dict, List
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
import os
# Hub ids of the two hosted models: one answers, one rewrites follow-ups.
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# Access token, read from the "apiToken" environment variable (None if unset).
hf_token = os.environ.get("apiToken")

# Generation settings for the answering model.
kwargs = dict(
    max_new_tokens=500,
    temperature=0.9,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)

# Generation settings for the question-rewriting model (shorter, cooler).
reform_kwargs = dict(
    max_new_tokens=50,
    temperature=0.5,
    top_p=0.9,
    repetition_penalty=1.0,
    do_sample=True,
)
class KwArgsModel(BaseModel):
    """Pydantic model holding a free-form mapping of keyword options."""

    # default_factory gives every instance its own fresh dict when no value
    # is supplied, instead of sharing one mutable default.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by a Hugging Face ``InferenceClient``.

    The generation options stored in ``kwargs`` are forwarded unchanged to
    ``InferenceClient.text_generation`` on every call.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the inference client and initialise the model fields.

        Args:
            model_name: Model identifier handed to ``InferenceClient``.
            hf_token: Access token handed to ``InferenceClient``.
            kwargs: Generation options for ``text_generation``; ``None``
                means "no extra options".
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # NOTE(review): no ``hf_token`` field is declared on this class;
            # it is forwarded exactly as before. Whether the base model
            # ignores or rejects it depends on its pydantic config -- confirm.
            hf_token=hf_token,
            # The ``kwargs`` field is a non-optional dict and ``_call`` unpacks
            # it with ``**``, so the ``None`` default must become an empty dict.
            kwargs={} if kwargs is None else kwargs,
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for ``prompt`` and return it as one string.

        Raises:
            ValueError: If ``stop`` is given; stop sequences are not supported.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # The reply is requested as a stream and the pieces are joined here,
        # so callers still receive a single complete string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        response = ''.join(response_gen)
        return response

    @property
    def _llm_type(self) -> str:
        """Short label identifying this LLM implementation."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM instance (the model name)."""
        return {"model_name": self.model_name}
# LLM that writes the final answer.
chat_llm = CustomInferenceClient(
    model_name=chat_model_name,
    hf_token=hf_token,
    kwargs=kwargs,
)

# LLM that rewrites a follow-up message into a standalone one.
reform_llm = CustomInferenceClient(
    model_name=reform_model_name,
    hf_token=hf_token,
    kwargs=reform_kwargs,
)
# Answer prompt: the template text comes from the project's config module and
# is filled with the retrieved context, the question and the chat history.
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template=prompt_template,
)

# Passed to the chain below as its combine_docs_chain_kwargs.
chain_type_kwargs = dict(prompt=PROMPT)
# --- Dense retrieval: FAISS index queried by similarity, 5 hits per query ---
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs=dict(k=5),
)

# --- Compression pipeline shared by both retrievers ---
# Stages run in list order: split on ". ", then the redundancy filter, then
# the relevance filter with a 0.76 similarity threshold.
splitter = CharacterTextSplitter(separator=". ", chunk_size=300, chunk_overlap=0)
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter],
)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=retriever,
    base_compressor=pipeline_compressor,
)

# --- Keyword retrieval: BM25 over the pickled texts, 2 hits per query ---
# NOTE: unpickling can execute arbitrary code; docs_data.pkl must come from a
# trusted source.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(
    base_retriever=bm25_retriever,
    base_compressor=pipeline_compressor,
)

# --- Ensemble: both compressed retrievers, equally weighted ---
ensemble_retriever = EnsembleRetriever(
    retrievers=[compression_retriever, bm25_compression_retriever],
    weights=[0.5, 0.5],
)
# Prompt text used to turn a follow-up message plus the chat history into a
# standalone message. Placeholders: {chat_history} and {question}.
custom_template = (
    "Given the following conversation and a follow-up message, "
    "rephrase the follow-up user message to be a standalone message. "
    "If the follow-up message is not a question, keep it unchanged[/INST].\n"
    "Chat History:\n"
    "{chat_history}\n"
    "Follow-up user message: {question}\n"
    "Rewritten user message:"
)
# Prompt object for the question-rewriting step.
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)

# Conversational retrieval chain: reform_llm condenses the follow-up with the
# prompt above, the ensemble retriever supplies the documents, and chat_llm
# answers using the "stuff" chain type.
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    condense_question_llm=reform_llm,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    retriever=ensemble_retriever,
    chain_type="stuff",
    combine_docs_chain_kwargs=chain_type_kwargs,
    memory=memory3,
    # The chat history is passed through to the prompts unchanged.
    get_chat_history=lambda history: history,
    return_source_documents=True,
)