Spaces:
Runtime error
Runtime error
from langchain import PromptTemplate, OpenAI, LLMChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import Chroma | |
from langchain.chains import RetrievalQAWithSourcesChain | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.docstore.document import Document | |
import chainlit as cl | |
from chainlit import user_session | |
import pandas as pd | |
# Directory in which the Chroma vector store is persisted.
persist_directory = "vector_db"

# Prompt for a plain step-by-step question-answering chain. It is not
# referenced by any active code visible in this file.
template = """Question: {question}
Answer: Let's think step by step."""

# Get processed data from a json file
# PRODUCTS_DATA = pd.read_json('data/bestbuy-dataset-products.json').sample(n=3).to_dict(orient='records')
PRODUCTS_DATA = []
@cl.on_chat_start
def main():
    """Index the product catalogue and store a retrieval chain in the session.

    Reads the product records from ``data/bestbuy-dataset-products.json``,
    turns each record into a ``Document`` (the natural-language product spec
    is the page content; sku, name, url and image go into the metadata),
    embeds all of them into a persisted Chroma store and keeps a
    ``RetrievalQAWithSourcesChain`` over that store in the user session under
    the key ``"llm_chain"``.

    Each record is expected to provide the keys ``sku``, ``name``, ``url`` and
    ``product_spec_in_natural_language``; ``image`` is optional.

    Raises:
        ValueError: If the data file contains no product records.
    """
    embeddings = OpenAIEmbeddings(
        disallowed_special=(),
    )

    # Kept local on purpose: the module-level PRODUCTS_DATA constant is not
    # rebound by this function.
    products = pd.read_json('data/bestbuy-dataset-products.json').to_dict(orient='records')

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=20,
        length_function=len,
    )

    # Collect one Document per product so that the whole catalogue is indexed
    # in a single call, instead of only the last record surviving the loop.
    docs = []
    for item in products:
        metadata = {"source": item["sku"], "name": item["name"], "url": item["url"]}
        # Omit a missing image instead of failing or storing None; the message
        # handler reads this field with .get().
        # NOTE(review): confirm whether the vector store accepts None metadata.
        image = item.get("image")
        if image is not None:
            metadata["image"] = image
        docs.append(
            Document(
                page_content=item["product_spec_in_natural_language"],
                metadata=metadata,
            )
        )

    if not docs:
        raise ValueError("No product records found in data/bestbuy-dataset-products.json")

    documents = text_splitter.split_documents(docs)

    # NOTE(review): this runs for every new chat, so the catalogue is re-read
    # and re-embedded per session; consider reusing the persisted store.
    vectordb = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    vectordb.persist()

    # Create a chain that uses the Chroma vector store
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature=0,
        ),
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
        # Needed so the message handler can list the matching products.
        return_source_documents=True,
    )

    # Store the chain in the user session
    cl.user_session.set("llm_chain", chain)
@cl.on_message
async def main(message: str):
    """Answer a user message with the session's retrieval chain.

    Runs the chain stored under ``"llm_chain"`` in the user session, then
    sends the answer back together with one inline text element per distinct
    source product (its name as the element name, its url as the content).
    When the answer says the information was not found, no product elements
    are attached.

    Args:
        message: The text the user sent.
    """
    # Retrieve the chain from the user session
    llm_chain = cl.user_session.get("llm_chain")  # type: RetrievalQAWithSourcesChain
    if llm_chain is None:
        # Without a chain there is nothing to query; report it instead of
        # failing with an AttributeError on None.
        await cl.Message(content="The product search is not ready yet. Please start a new chat.").send()
        return

    # Call the chain asynchronously
    res = await llm_chain.acall(message, callbacks=[cl.AsyncLangchainCallbackHandler()])
    print(res)

    # This varies from chain to chain, you should check which key to read.
    answer = res["answer"]

    # Several retrieved chunks can belong to the same product; keep the first
    # metadata seen for each product id so every product is listed once.
    unique_sources = {}
    for source in res["source_documents"]:
        doc_id = source.metadata["source"]
        if doc_id not in unique_sources:
            unique_sources[doc_id] = {
                "url": source.metadata.get("url"),
                "name": source.metadata.get("name"),
                "image": source.metadata.get("image"),
            }

    not_found_indicators = ("not mentioned", "no mention", "not specified", "no information")
    lowered_answer = answer.lower()
    if any(indicator in lowered_answer for indicator in not_found_indicators):
        # If product not found, do not show any product urls
        source_elements = []
    else:
        # One element per product, rather than overwriting the list on each
        # iteration and ending up with only the last product.
        source_elements = [
            cl.Text(
                content=f"Product url: {values['url']}\n",
                name=values["name"],
                display="inline",
            )
            for values in unique_sources.values()
        ]

    await cl.Message(content=answer, elements=source_elements).send()