File size: 5,051 Bytes
36c0029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
import pickle
import config
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from memory import memory3
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
from typing import Any, Optional, Dict, List
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM

import os
# Export the OpenAI key from the project config into the process environment
# so OpenAI-backed components created later in this module can pick it up.
os.environ["OPENAI_API_KEY"] = config.OPENAI_API_KEY

# Hugging Face model ids: one for answering, one for rewriting follow-up
# messages into standalone ones.
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# SECURITY: an API token must not live in source control. Resolve it from the
# HF_TOKEN environment variable first, then from an optional ``config.HF_TOKEN``
# attribute. The literal below is kept ONLY as a backward-compatible fallback
# so existing deployments keep working.
# TODO: rotate this token and delete the literal once HF_TOKEN is provisioned.
hf_token = (
    os.environ.get("HF_TOKEN")
    or getattr(config, "HF_TOKEN", None)
    or "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"
)

# Generation parameters for the answering model (longer, more varied output).
kwargs = {
    "max_new_tokens": 500,
    "temperature": 0.9,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
    "do_sample": True,
}

# Generation parameters for the question-rewriting model (short, less random).
reform_kwargs = {
    "max_new_tokens": 50,
    "temperature": 0.5,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "do_sample": True,
}

class KwArgsModel(BaseModel):
    """Pydantic model that contributes a single ``kwargs`` dictionary field."""

    # Free-form keyword arguments; ``default_factory=dict`` gives every
    # instance its own empty dict when no value is supplied.
    kwargs: Dict[str, Any] = Field(default_factory=dict)

class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by a ``huggingface_hub.InferenceClient``.

    Each prompt is sent to the client's ``text_generation`` endpoint using
    the generation parameters stored in ``kwargs``; the streamed pieces are
    joined and returned as a single string.
    """

    # Model id the inference client was created for.
    model_name: str
    # Client used to issue text-generation requests.
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the inference client and initialise the model fields.

        Args:
            model_name: Model id passed to ``InferenceClient``.
            hf_token: Hugging Face API token passed to ``InferenceClient``.
            kwargs: Generation parameters forwarded to ``text_generation`` on
                every call. ``None`` (the default) means no extra parameters.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # ``kwargs`` is a non-optional dict field, so the advertised
            # ``None`` default must be normalised before validation.
            # The token is deliberately not forwarded: no ``hf_token`` field
            # is declared, and only the client above needs it.
            kwargs=dict(kwargs) if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text sent to the model.
            stop: Must be ``None``; stop sequences are not supported here.

        Returns:
            The generated text only (``return_full_text=False``).

        Raises:
            ValueError: If ``stop`` is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Streaming yields the output piece by piece; join it into one string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM implementation."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this instance (the model name)."""
        return {"model_name": self.model_name}
    

# LLM that produces the final answer.
chat_llm = CustomInferenceClient(
    model_name=chat_model_name,
    hf_token=hf_token,
    kwargs=kwargs,
)

# LLM that rewrites follow-up messages into standalone ones.
reform_llm = CustomInferenceClient(
    model_name=reform_model_name,
    hf_token=hf_token,
    kwargs=reform_kwargs,
)



# Answer prompt: the template text comes from the project config and is
# filled with the retrieved context, the user question and the chat history.
prompt_template = config.DEFAULT_CHAT_TEMPLATE

PROMPT = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template=prompt_template,
)

# Extra arguments handed to the document-combining chain.
chain_type_kwargs = dict(prompt=PROMPT)

# --- Dense retrieval: similarity search (k=5) over the saved FAISS index ---
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)

retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 5},
)

# --- Compression pipeline shared by both retrievers ---
# Split on ". " into chunks of at most 300 characters, then apply the
# redundancy filter and the relevance filter (similarity threshold 0.76).
splitter = CharacterTextSplitter(
    chunk_size=300,
    chunk_overlap=0,
    separator=". ",
)
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.76,
)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter],
)

compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor,
    base_retriever=retriever,
)

# --- Keyword retrieval: BM25 over the pickled texts, top 2 results ---
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# docs_data.pkl from a trusted source.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)

bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2

bm25_compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor,
    base_retriever=bm25_retriever,
)

# --- Hybrid retriever: dense and BM25 results combined with equal weight ---
ensemble_retriever = EnsembleRetriever(
    retrievers=[compression_retriever, bm25_compression_retriever],
    weights=[0.5, 0.5],
)


# Prompt for rewriting a follow-up message into a standalone one. It is
# filled with ``chat_history`` and ``question``.
# NOTE(review): the template contains a closing "[/INST]" marker with no
# opening "[INST]", and it sits before the final period — presumably meant
# for the Mistral instruct format; confirm it is intentional.
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged[/INST].

Chat History:
{chat_history}

Follow-up user message: {question}
Rewritten user message:"""

# Input variables are inferred from the placeholders in the template.
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)


# Conversational RAG chain: ``reform_llm`` condenses each follow-up into a
# standalone question, ``ensemble_retriever`` fetches context, and ``chat_llm``
# answers using the "stuff" document-combining strategy.
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    condense_question_llm=reform_llm,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    retriever=ensemble_retriever,
    chain_type="stuff",
    combine_docs_chain_kwargs=chain_type_kwargs,
    memory=memory3,
    # Hand the stored history to the prompts unchanged.
    get_chat_history=lambda history: history,
    return_source_documents=True,
)