import os
import pickle
from typing import Any, Dict, List, Optional

import config
from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline, EmbeddingsFilter
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from pydantic import BaseModel, Field

from memory import memory3

# Model endpoints: zephyr-7b-alpha answers the user; Mistral-7B-Instruct rewrites
# follow-up messages into standalone questions before retrieval.
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = os.getenv("apiToken")

# Generation settings: the chat model samples freely and may answer at length;
# the reformulation model is kept short and conservative.
kwargs = {"max_new_tokens": 500, "temperature": 0.9, "top_p": 0.95,
          "repetition_penalty": 1.0, "do_sample": True}
reform_kwargs = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.9,
                 "repetition_penalty": 1.0, "do_sample": True}


class KwArgsModel(BaseModel):
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around the Hugging Face InferenceClient."""

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        # Pass kwargs or {} so a None value does not fail pydantic validation
        # against the Dict[str, Any] field.
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the endpoint and join them into a single string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}


chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)

# Prompt used by the "stuff" combine-docs chain.
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question", "chat_history"],
)
chain_type_kwargs = {"prompt": PROMPT}

# Dense retrieval: FAISS index over HuggingFace embeddings, top 5 by similarity.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

# Compression pipeline: re-split retrieved documents into small chunks, drop
# near-duplicates, then keep only chunks similar enough to the query.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". \n")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)

# Sparse retrieval: BM25 over the raw documents, top 2, behind the same compressor.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=bm25_retriever
)

# Hybrid retrieval: dense and sparse results merged with equal weight.
ensemble_retriever = EnsembleRetriever(
    retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5]
)

# Prompt for condensing a follow-up message into a standalone question
# (Mistral instruct format, hence the [/INST] tag closing the instruction).
custom_template = """Given the following conversation and a follow-up message, \
rephrase the follow-up user message to be a standalone message. If the follow-up \
message is not a question, keep it unchanged.[/INST]
Chat History:
{chat_history}
Follow-up user message: {question}
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)

chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    combine_docs_chain_kwargs=chain_type_kwargs,
    return_source_documents=True,
    get_chat_history=lambda h: h,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    memory=memory3,
    condense_question_llm=reform_llm,
)