import uuid
from io import StringIO

import streamlit as st
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.schema import Document
from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters

# The key must match the entry in .streamlit/secrets.toml
# (the original misspelled it as "INFRERENCE_API_TOKEN").
inference_api_key = st.secrets["INFERENCE_API_TOKEN"]

embed_model_name = st.text_input(
    'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")

# Label fixed: the original reused 'Embed Model name' for this field,
# which mislabels the LLM input in the UI.
llm_model_name = st.text_input(
    'LLM Model name', "mistralai/Mistral-7B-Instruct-v0.2")

query = st.text_input(
    'Query', "What is the price of the product?")

html_file = st.file_uploader("Upload an HTML file", type=["html"])

if st.button('Start Pipeline'):
    # st.text_input returns a string (possibly empty), never None, so the
    # string fields are checked for truthiness rather than against None.
    if html_file is not None and embed_model_name and llm_model_name and query:
        st.write('Running Pipeline')

        llm = HuggingFaceInferenceAPI(
            model_name=llm_model_name, token=inference_api_key)

        # The original passed model_kwargs={"device": ""} and
        # encode_kwargs={"normalize_embeddings": True}; those settings apply
        # to locally hosted sentence-transformers models, while this wrapper
        # runs embeddings remotely via the Inference API, so they are dropped.
        embed_model = HuggingFaceInferenceAPIEmbedding(
            model_name=embed_model_name,
            token=inference_api_key,
        )

        service_context = ServiceContext.from_defaults(
            embed_model=embed_model, llm=llm)

        # UploadedFile.getvalue() returns bytes; decode directly instead of
        # routing through an intermediate StringIO.
        string_data = html_file.getvalue().decode("utf-8")
        with st.expander("Uploaded HTML"):
            st.write(string_data)

        # Tag the document with a unique id so the query engine can be
        # restricted to exactly this upload via a metadata filter.
        # VectorStoreIndex.from_documents() has no metadata parameter (the
        # original passed metadata={"source": "HTML"} there); per-document
        # metadata belongs on the Document itself.
        document_id = str(uuid.uuid4())
        document = Document(text=string_data)
        document.metadata["id"] = document_id
        document.metadata["source"] = "HTML"
        documents = [document]

        filters = MetadataFilters(
            filters=[ExactMatchFilter(key="id", value=document_id)])

        index = VectorStoreIndex.from_documents(
            documents, show_progress=True, service_context=service_context)

        # Show the retrieved chunks and their similarity scores for debugging.
        retriever = index.as_retriever()
        ranked_nodes = retriever.retrieve(query)
        with st.expander("Ranked Nodes"):
            for node in ranked_nodes:
                st.write(node.node.get_content(), "-> Score:", node.score)

        query_engine = index.as_query_engine(
            filters=filters, service_context=service_context)
        response = query_engine.query(query)
        st.write(response)
    else:
        st.error('Please fill in all the fields')
else:
    st.write('Press start to begin')
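
# ---------------------------------------------------------------------------
# Usage note (a minimal sketch; the filename app.py and the secrets layout
# below are assumptions, not part of the original script): the Inference API
# token read above comes from Streamlit's .streamlit/secrets.toml, e.g.
#
#   INFERENCE_API_TOKEN = "hf_..."   # your Hugging Face Inference API token
#
# The app is then started with:
#
#   streamlit run app.py
# ---------------------------------------------------------------------------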