File size: 3,485 Bytes
a1e1bfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ed5a9
 
a1e1bfa
 
 
 
 
 
 
 
 
 
ebe5d57
a1e1bfa
 
 
 
 
 
 
 
7e18dc0
 
 
 
a1e1bfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afe0594
a1e1bfa
 
 
ebe5d57
a1e1bfa
 
 
 
 
 
 
 
 
 
a82ed0c
6386ac4
64044be
6386ac4
 
 
a1e1bfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc76327
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
import chainlit as cl

# System prompt for the QA chain. The {context} placeholder at the end is the
# template variable the system message exposes; everything above it is fixed
# instruction text.
system_template = """
Use the following pieces of context to answer the user's question.
Please respond as an air-headed beach bro.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Example of your response should be:

```
The answer is foo
```

Begin!
----------------
{context}"""

# Splitter applied to the loaded CSV documents before they are embedded
# (chunk_size=1000, chunk_overlap=100).
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# Chat prompt = system instructions followed by the user's question.
_system_message = SystemMessagePromptTemplate.from_template(system_template)
_human_message = HumanMessagePromptTemplate.from_template("{question}")
messages = [_system_message, _human_message]
prompt = ChatPromptTemplate(messages=messages)

# Extra keyword arguments intended for the chain constructor.
chain_type_kwargs = dict(prompt=prompt)

@cl.author_rename
def rename(orig_author: str):
    """Return the display name to use for ``orig_author``.

    Only "RetrievalQA" is remapped; any other author name is returned
    unchanged.
    """
    display_names = {"RetrievalQA": "Consulting The Kens"}
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author

@cl.on_chat_start
async def init():
    """Build the review index and the QA chain when a chat session starts.

    Steps performed:
      1. Send a "Building Index..." status message.
      2. Load ``./data/barbie.csv`` (the ``Review_Url`` column becomes each
         document's source) and split the documents with the module-level
         ``text_splitter``.
      3. Embed the chunks with ``OpenAIEmbeddings`` wrapped in a
         ``CacheBackedEmbeddings`` backed by a ``LocalFileStore`` at
         ``./cache/``, and build a FAISS index from them.
      4. Create a ``RetrievalQA`` "stuff" chain over that index and store it
         in the user session under the key ``"chain"``.
      5. Update the status message to "Index built!".
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # build FAISS index from csv
    loader = CSVLoader(file_path="./data/barbie.csv", source_column="Review_Url")
    data = loader.load()
    documents = text_splitter.transform_documents(data)
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    # Namespace the cache by model name so embeddings from different models
    # do not share cache entries.
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )
    # FAISS.from_documents is a plain (sync) call; wrap it so it can be awaited.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        # Reuse the module-level kwargs instead of rebuilding the same dict.
        chain_type_kwargs=chain_type_kwargs,
    )

    # Make the chain available before announcing readiness, so a message
    # handler never observes "Index built!" without a chain in the session.
    cl.user_session.set("chain", chain)

    # Edit the status message that was already sent rather than sending it
    # a second time.
    msg.content = "Index built!"
    await msg.update()


@cl.on_message
async def main(message):
    """Answer an incoming chat message with the session's QA chain.

    The chain stored under ``"chain"`` in the user session is called with
    ``message``; its ``"result"`` becomes the reply text. The ``"source"``
    metadata of every returned source document is de-duplicated (first
    occurrence wins), prefixed with ``https://www.imdb.com`` and attached to
    the reply both as ``cl.Text`` elements and as a trailing "Sources:" line.
    When no source documents are returned, "No sources found" is appended
    instead.

    Args:
        message: The incoming user message, forwarded unchanged to
            ``chain.acall``.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    res = await chain.acall(message, callbacks=[cb])

    answer = res["result"]

    # Source documents come from the chain result (return_source_documents).
    # dict.fromkeys drops duplicates while preserving first-seen order.
    unique_sources = dict.fromkeys(
        doc.metadata["source"] for doc in res["source_documents"]
    )
    source_urls = ["https://www.imdb.com" + source for source in unique_sources]

    # Create the text elements referenced in the message.
    source_elements = [
        cl.Text(content=url, name="Review URL") for url in source_urls
    ]

    if source_urls:
        # Join the URL strings directly instead of reading them back from the
        # elements, so this does not depend on whether cl.Text keeps its
        # content as str or bytes.
        answer += f"\nSources: {', '.join(source_urls)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()