Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
import torch.nn.functional as F | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
from sklearn.manifold import TSNE | |
import plotly.express as px | |
import torch | |
import plotly.io as pio | |
# Page-level configuration. set_page_config is issued before any other
# Streamlit command, which older Streamlit versions require.
pio.templates.default = "plotly"
st.set_page_config(layout="wide")
# Italics go inside the colour directive: the previous "_ ... _" form had
# whitespace inside the underscores, which Markdown does not treat as emphasis,
# so the underscores were rendered literally.
st.header(
    "Explore the Russian Dolls :nesting_dolls: - :green[_Nomic Embed 1.5_]",
    divider='violet',
)
st.write("Matryoshka Representation Learning : to learn more :https://aniketrege.github.io/blog/2024/mrl/")
@st.cache_data
def get_df():
    """Load the sample product catalogue from ``./sample_products.csv``.

    Streamlit re-executes this script on every widget interaction, so the
    result is cached with ``st.cache_data``. Without it, the CSV would be
    re-read on every click. ``st.cache_data`` hands each caller its own copy
    of the cached frame.

    Returns:
        pandas.DataFrame: the parsed CSV. The app uses its ``Description``
        column.
    """
    return pd.read_csv("./sample_products.csv")
@st.cache_resource
def get_nomicModel():
    """Return the Nomic Embed v1.5 SentenceTransformer model.

    Cached with ``st.cache_resource`` so the model is loaded once per server
    process. Streamlit reruns the whole script on each interaction, and
    reloading the model every time is very slow.

    Returns:
        SentenceTransformer: the ``nomic-ai/nomic-embed-text-v1.5`` model.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository. Keep the model id pinned to a trusted source.
    return SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
def get_searchQueryEmbedding(query, encoder=None):
    """Embed a search query and layer-normalise the result.

    The ``"search_query: "`` prefix is the task instruction Nomic Embed expects
    for queries. Documents carry ``"search_document:"`` instead. Truncation to
    a Matryoshka dimension and L2 normalisation happen afterwards in
    ``get_normEmbed``.

    Args:
        query: Free-text search query.
        encoder: Object with a SentenceTransformer-compatible ``encode``
            method. Defaults to the module-level ``model``. Pass one
            explicitly to avoid depending on that global (e.g. in tests).

    Returns:
        torch.Tensor: layer-normalised embedding with one row per encoded
        query (a single row here).
    """
    if encoder is None:
        encoder = model
    embeddings = encoder.encode(["search_query: " + query], convert_to_tensor=True)
    # Layer-norm over the feature axis *before* truncation, following the
    # Matryoshka recipe published for nomic-embed-text-v1.5.
    return F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    """Apply the Matryoshka truncation to both embedding matrices.

    Each 2-D tensor keeps only its first ``matryoshka_dim`` columns, and every
    row is then scaled to unit L2 norm. After this, a dot product between rows
    is a cosine similarity.

    Args:
        query_embedding: 2-D tensor of query embeddings.
        loaded_embed: 2-D tensor of document embeddings.
        matryoshka_dim: Number of leading dimensions to keep.

    Returns:
        tuple: ``(truncated_normalised_queries, truncated_normalised_docs)``.
    """
    def _shrink_and_unit(matrix):
        prefix = matrix[:, :matryoshka_dim]
        return F.normalize(prefix, p=2, dim=1)

    queries_out = _shrink_and_unit(query_embedding)
    docs_out = _shrink_and_unit(loaded_embed)
    return queries_out, docs_out
def insert_line_breaks(text, interval=30):
    """Wrap *text* with HTML ``<br>`` tags so Plotly hover labels stay narrow.

    Words (split on single spaces) are added to the current line. A line is
    closed once its length, counting one separator per word, reaches
    *interval*. Lines are joined with ``<br>`` and no break is emitted after
    the final line. The previous implementation could end the label with a
    dangling ``<br>`` (an empty last hover line) and left a space before each
    break.

    Args:
        text: The string to wrap.
        interval: Approximate number of characters per line.

    Returns:
        str: the wrapped text with surrounding whitespace stripped.
    """
    lines = []
    current = []
    length = 0
    for word in text.split(' '):
        current.append(word)
        length += len(word) + 1
        if length >= interval:
            lines.append(' '.join(current))
            current = []
            length = 0
    if current:
        lines.append(' '.join(current))
    return '<br>'.join(lines).strip()
# Load the embedding model, the product descriptions and the precomputed
# document embeddings. NOTE(review): row i of prodBigDollEmbeddings.npy is
# assumed to be the embedding of row i of sample_products.csv. Confirm both
# files were generated together.
model = get_nomicModel()
bigDollEmbedding = get_df()["Description"]
docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))

TOP_K = 10  # number of nearest products listed in the results table

# Outside the form: switching between the sample-query picker and free text
# must rerun the script immediately so the matching widget is shown.
toggle = st.toggle('sample queries')
with st.form("my_form"):
    if toggle:
        question_input = st.selectbox('select a query:',
            ('Pack of two assorted boxers, has two pockets, an elasticated waistbandDisclaimer: The final product delivered might vary in colour and prints from the display here.',
             'Beige self design shoulder bag, has a zip closure1 main compartment, 3 inner pocketsTwo Handles',
             'Set Content: 1 photo frameColour: Black and whiteFrame Pattern: SolidShape: SquareMaterial: Acrylic',
             'A pair of dark grey solid boxers, has a slip-on closure with an elasticated waistband and drawstring, two pocket',
             'Red & Black solid sweatshirt, has a hood, two pockets, long sleeves, zip closure, straight hem'))
    else:
        # A non-empty but hidden label avoids Streamlit's empty-label
        # accessibility warning while keeping the layout unchanged.
        question_input = st.text_input("Search query", label_visibility="collapsed")
    # Inside the form so moving the slider does not trigger a rerun that wipes
    # the displayed results; the chosen dimension is applied on Submit.
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
    submitted = st.form_submit_button("Submit")

if submitted and not question_input.strip():
    st.warning("Please enter a search query.")
elif submitted:
    # The widget value is `question_input`. The old code read an undefined
    # `query_input`, which raised NameError on every submit.
    # .cpu(): the document matrix lives on CPU, and .numpy() below needs a
    # CPU tensor (encode() may return a CUDA tensor on GPU hosts).
    queryEmbedding = get_searchQueryEmbedding(question_input).cpu()
    query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)
    # Both sides are row-wise L2-normalised, so this is cosine similarity.
    similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
    k = min(TOP_K, similarity_scores.shape[1])
    top_values, top_indices = torch.topk(similarity_scores, k, dim=1)

    # Convert the Series once instead of once per selected index.
    descriptions = bigDollEmbedding.tolist()
    top_items_per_query = [descriptions[i] for i in top_indices[0].tolist()]
    # Plain floats (not a torch tensor) so the Score column is numeric.
    df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0].tolist()})
    df["Product"] = df["Product"].str.replace("search_document:", "")

    # 2-D t-SNE of the query together with every document; the query is row 0.
    allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
    tsne = TSNE(n_components=2, random_state=0)
    projections = tsne.fit_transform(allEmbedd.numpy())
    # Wrap the hover text so long descriptions stay readable.
    listHover = [insert_line_breaks(hover_text, 30) for hover_text in descriptions]
    fig = px.scatter(
        projections, x=0, y=1,
        hover_name=[question_input] + listHover,
        # One label per plotted point, derived from the data rather than a
        # hard-coded document count.
        color=["search_query"] + ["search_document"] * len(listHover),
    )
    col1, col2 = st.columns([2, 2])
    col2.plotly_chart(fig, use_container_width=True)
    col1.dataframe(df)
st.caption("Dataset Credit : kaggle")