Spaces:
Sleeping
Sleeping
File size: 5,334 Bytes
8b6399b 2145acc 8b6399b 6ad144b 8b6399b 6ad144b 2145acc 6ad144b 8b6399b 16dcc46 6ad144b 16dcc46 8b6399b 6ad144b 2145acc 6ad144b 2145acc 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 242bba0 6ad144b 5b24b6b 6ad144b 242bba0 6ad144b 242bba0 6ad144b 8b6399b 6ad144b 8b6399b 6ad144b 8b6399b 6ad144b 8b6399b 6ad144b 8b6399b 6ad144b df26c41 6ad144b df26c41 6ad144b df26c41 6ad144b 8b6399b 6ad144b 8b6399b 6ad144b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import streamlit as st
import os
from io import StringIO
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import Document
import uuid
from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
from typing import List
from pydantic import BaseModel
# Hugging Face Inference API token, read from Streamlit secrets.
# NOTE(review): the key "INFRERENCE_API_TOKEN" looks misspelled
# ("INFRERENCE" vs "INFERENCE"); it must match the name configured in
# .streamlit/secrets.toml exactly, so confirm before renaming either side.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
# Model names were once user-configurable; now hard-coded below.
# embed_model_name = st.text_input(
#     'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
# llm_model_name = st.text_input(
#     'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")
class PriceModel(BaseModel):
    """Data model for price"""
    # Extracted product price. Kept as a free-form string because the LLM
    # may include a currency symbol/formatting; the docstring above also
    # serves as the schema description for the structured-output query
    # (used via `output_cls` below), so keep it meaningful.
    price: str
# Hard-coded model choices: a small English embedding model and an
# instruction-tuned LLM, both served via the HF Inference API.
embed_model_name = "jinaai/jina-embedding-s-en-v1"
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
llm = HuggingFaceInferenceAPI(
    model_name=llm_model_name, token=inference_api_key)
embed_model = HuggingFaceInferenceAPIEmbedding(
    model_name=embed_model_name,
    token=inference_api_key,
    # NOTE(review): model_kwargs/encode_kwargs are sentence-transformers
    # style options; whether the Inference-API embedding class actually
    # forwards/honours them (and an empty "device") is unclear — verify
    # against the installed llama_index version.
    model_kwargs={"device": ""},
    encode_kwargs={"normalize_embeddings": True},
)
# Bundle LLM + embedding model so both index build and querying use them.
service_context = ServiceContext.from_defaults(
    embed_model=embed_model, llm=llm)
# User-facing controls: the question to ask, and the HTML page to index.
query = st.text_input(
    'Query', "What is the price of the product?"
)
html_file = st.file_uploader("Upload a html file", type=["html"])
if html_file is not None:
    # Decode the uploaded file's bytes into text and show it for inspection.
    stringio = StringIO(html_file.getvalue().decode("utf-8"))
    string_data = stringio.read()
    with st.expander("Uploaded HTML"):
        st.write(string_data)
    # Tag the document with a unique id and filter retrieval on it, so the
    # query only sees this upload even if the store holds older documents.
    document_id = str(uuid.uuid4())
    document = Document(text=string_data)
    document.metadata["id"] = document_id
    documents = [document]
    filters = MetadataFilters(
        filters=[ExactMatchFilter(key="id", value=document_id)])
    # NOTE(review): `metadata={"source": "HTML"}` is not a documented
    # from_documents parameter — it is forwarded via **kwargs; confirm the
    # installed llama_index version accepts it (per-document metadata
    # normally goes on Document.metadata instead).
    index = VectorStoreIndex.from_documents(
        documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
    # tree_summarize + output_cls makes the engine return a structured
    # PriceModel as the response payload.
    query_engine = index.as_query_engine(
        filters=filters, service_context=service_context, response_mode="tree_summarize", output_cls=PriceModel)
    response = query_engine.query(query)
    st.write(response.response)
    # Bug fix: with output_cls, the PriceModel instance is the Response
    # wrapper's payload (`response.response`, as written on the line
    # above); the wrapper itself has no `.price` attribute.
    st.write(f'Price: {response.response.price}')
# if st.button('Start Pipeline'):
# if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
# st.write('Running Pipeline')
# llm = HuggingFaceInferenceAPI(
# model_name=llm_model_name, token=inference_api_key)
# embed_model = HuggingFaceInferenceAPIEmbedding(
# model_name=embed_model_name,
# token=inference_api_key,
# model_kwargs={"device": ""},
# encode_kwargs={"normalize_embeddings": True},
# )
# service_context = ServiceContext.from_defaults(
# embed_model=embed_model, llm=llm)
# stringio = StringIO(html_file.getvalue().decode("utf-8"))
# string_data = stringio.read()
# with st.expander("Uploaded HTML"):
# st.write(string_data)
# document_id = str(uuid.uuid4())
# document = Document(text=string_data)
# document.metadata["id"] = document_id
# documents = [document]
# filters = MetadataFilters(
# filters=[ExactMatchFilter(key="id", value=document_id)])
# index = VectorStoreIndex.from_documents(
# documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
# retriever = index.as_retriever()
# ranked_nodes = retriever.retrieve(
# query)
# with st.expander("Ranked Nodes"):
# for node in ranked_nodes:
# st.write(node.node.get_content(), "-> Score:", node.score)
# query_engine = index.as_query_engine(
# filters=filters, service_context=service_context)
# response = query_engine.query(query)
# st.write(response.response)
# st.write(response.source_nodes)
# else:
# st.error('Please fill in all the fields')
# else:
# st.write('Press start to begin')
# # if html_file is not None:
# # stringio = StringIO(html_file.getvalue().decode("utf-8"))
# # string_data = stringio.read()
# # with st.expander("Uploaded HTML"):
# # st.write(string_data)
# # document_id = str(uuid.uuid4())
# # document = Document(text=string_data)
# # document.metadata["id"] = document_id
# # documents = [document]
# # filters = MetadataFilters(
# # filters=[ExactMatchFilter(key="id", value=document_id)])
# # index = VectorStoreIndex.from_documents(
# # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
# # retriever = index.as_retriever()
# # ranked_nodes = retriever.retrieve(
# # "Get me all the information about the product")
# # with st.expander("Ranked Nodes"):
# # for node in ranked_nodes:
# # st.write(node.node.get_content(), "-> Score:", node.score)
# # query_engine = index.as_query_engine(
# # filters=filters, service_context=service_context)
# # response = query_engine.query(
# # "Get me all the information about the product")
# # st.write(response)
|