# NOTE(review): removed non-Python residue copied from a web view of this
# file (a "File size" header, git-blame short hashes, and the line-number
# gutter 1-166). None of it was source code; the script proper starts at
# the imports below.
import streamlit as st
import os
from io import StringIO
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import Document
import uuid
from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
from typing import List
from pydantic import BaseModel

# Hugging Face Inference API token, read from Streamlit secrets
# (.streamlit/secrets.toml or the deployment's secret store).
# NOTE(review): the key "INFRERENCE_API_TOKEN" looks misspelled
# ("INFERENCE") — but it must match the name used in the secrets store,
# so confirm both sides before renaming either.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]

# Model names were once user-configurable via these text inputs; they are
# now hard-coded further down (embed_model_name / llm_model_name).
# embed_model_name = st.text_input(
#     'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")

# llm_model_name = st.text_input(
#     'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")


class PriceModel(BaseModel):
    """Structured output model for the extracted product price.

    Passed as ``output_cls`` to the query engine below so the LLM answer
    is parsed into an object with a single ``price`` field.
    """
    price: str  # price kept as free-form text (e.g. "$19.99"), not parsed to a number


# Hard-coded model choices (the commented-out text_input widgets above
# used to make these configurable).
embed_model_name = "jinaai/jina-embedding-s-en-v1"
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Remote LLM served through the Hugging Face Inference API.
llm = HuggingFaceInferenceAPI(
    model_name=llm_model_name, token=inference_api_key)

# Remote embedding model via the same Inference API.
# NOTE(review): model_kwargs/encode_kwargs mirror the *local*
# HuggingFaceEmbedding options; whether the Inference API wrapper honours
# them (especially the empty "device" string) is unverified — check the
# llama_index docs for the version in use.
embed_model = HuggingFaceInferenceAPIEmbedding(
    model_name=embed_model_name,
    token=inference_api_key,
    model_kwargs={"device": ""},
    encode_kwargs={"normalize_embeddings": True},
)

# Bundle the LLM and embedder so the index and query engine share them.
service_context = ServiceContext.from_defaults(
    embed_model=embed_model, llm=llm)

# Free-text question to run against the uploaded page; defaults to a
# price lookup to match the PriceModel output class used below.
query = st.text_input(
    'Query', "What is the price of the product?"
)

# Single-file uploader restricted to .html; stays None until a file is chosen.
html_file = st.file_uploader("Upload a html file", type=["html"])

if html_file is not None:
    # Decode the uploaded bytes to text and show the raw HTML for inspection.
    buffer = StringIO(html_file.getvalue().decode("utf-8"))
    html_text = buffer.read()
    with st.expander("Uploaded HTML"):
        st.write(html_text)

    # Tag the document with a fresh UUID so retrieval can be scoped to
    # exactly this upload through a metadata filter.
    doc_id = str(uuid.uuid4())
    doc = Document(text=html_text)
    doc.metadata["id"] = doc_id

    doc_filters = MetadataFilters(
        filters=[ExactMatchFilter(key="id", value=doc_id)])

    # Build a vector index over the single uploaded document.
    vector_index = VectorStoreIndex.from_documents(
        [doc],
        show_progress=True,
        metadata={"source": "HTML"},
        service_context=service_context,
    )

    # Structured query: tree-summarize the retrieved chunks and parse the
    # answer into a PriceModel instance.
    engine = vector_index.as_query_engine(
        filters=doc_filters,
        service_context=service_context,
        response_mode="tree_summarize",
        output_cls=PriceModel,
    )

    result = engine.query(query)

    st.write(result.response)

    # NOTE(review): this assumes the pydantic field is exposed directly on
    # the response object; depending on the llama_index version the parsed
    # model may live in result.response instead — confirm.
    st.write(f'Price: {result.price}')

# if st.button('Start Pipeline'):
#     if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
#         st.write('Running Pipeline')
#         llm = HuggingFaceInferenceAPI(
#             model_name=llm_model_name, token=inference_api_key)

#         embed_model = HuggingFaceInferenceAPIEmbedding(
#             model_name=embed_model_name,
#             token=inference_api_key,
#             model_kwargs={"device": ""},
#             encode_kwargs={"normalize_embeddings": True},
#         )

#         service_context = ServiceContext.from_defaults(
#             embed_model=embed_model, llm=llm)

#         stringio = StringIO(html_file.getvalue().decode("utf-8"))
#         string_data = stringio.read()
#         with st.expander("Uploaded HTML"):
#             st.write(string_data)

#         document_id = str(uuid.uuid4())

#         document = Document(text=string_data)
#         document.metadata["id"] = document_id
#         documents = [document]

#         filters = MetadataFilters(
#             filters=[ExactMatchFilter(key="id", value=document_id)])

#         index = VectorStoreIndex.from_documents(
#             documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)

#         retriever = index.as_retriever()

#         ranked_nodes = retriever.retrieve(
#             query)

#         with st.expander("Ranked Nodes"):
#             for node in ranked_nodes:
#                 st.write(node.node.get_content(), "-> Score:", node.score)

#         query_engine = index.as_query_engine(
#             filters=filters, service_context=service_context)

#         response = query_engine.query(query)

#         st.write(response.response)

#         st.write(response.source_nodes)

#     else:
#         st.error('Please fill in all the fields')
# else:
#     st.write('Press start to begin')

# # if html_file is not None:
# #     stringio = StringIO(html_file.getvalue().decode("utf-8"))
# #     string_data = stringio.read()
# #     with st.expander("Uploaded HTML"):
# #         st.write(string_data)

# #     document_id = str(uuid.uuid4())

# #     document = Document(text=string_data)
# #     document.metadata["id"] = document_id
# #     documents = [document]

# #     filters = MetadataFilters(
# #         filters=[ExactMatchFilter(key="id", value=document_id)])

# #     index = VectorStoreIndex.from_documents(
# #         documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)

# #     retriever = index.as_retriever()

# #     ranked_nodes = retriever.retrieve(
# #         "Get me all the information about the product")

# #     with st.expander("Ranked Nodes"):
# #         for node in ranked_nodes:
# #             st.write(node.node.get_content(), "-> Score:", node.score)

# #     query_engine = index.as_query_engine(
# #         filters=filters, service_context=service_context)

# #     response = query_engine.query(
# #         "Get me all the information about the product")

# #     st.write(response)