Spaces:
Sleeping
Sleeping
File size: 4,415 Bytes
94ad0d9 d845ab5 94ad0d9 d845ab5 5158933 94ad0d9 d845ab5 7f91778 d845ab5 a6d5b0f eb5bdea 7f91778 a6d5b0f eb5bdea 5f1d89a eb5bdea 5f1d89a d845ab5 f27dba6 d845ab5 5158933 7b92747 d845ab5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import streamlit as st
import pandas as pd
import plotly.express as px
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.manifold import TSNE
import plotly.express as px
import torch
import plotly.io as pio
# Use the stock "plotly" template as the default for figures built below.
pio.templates.default = "plotly"

# Page chrome: wide layout, a title with a violet divider, and a pointer to
# the Matryoshka Representation Learning write-up.
st.set_page_config(layout="wide")
st.header(
    "Explore the Russian Dolls :nesting_dolls: - _ :green[Nomic Embed 1.5] _",
    divider="violet",
)
st.write(
    "Matryoshka Representation Learning : to learn more :https://aniketrege.github.io/blog/2024/mrl/"
)
@st.cache_data
def get_df():
    """Read ``./sample_products.csv`` and return it as a DataFrame.

    The function is decorated with ``st.cache_data``.
    """
    return pd.read_csv("./sample_products.csv")
@st.cache_resource
def get_nomicModel():
    """Construct the ``nomic-ai/nomic-embed-text-v1.5`` SentenceTransformer.

    The model is created with ``trust_remote_code=True`` and the function is
    decorated with ``st.cache_resource``.
    """
    return SentenceTransformer(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    )
def get_searchQueryEmbedding(query):
    """Embed one search query with the module-level ``model``.

    The text is prefixed with ``"search_query: "`` before encoding, and the
    encoded one-row batch is layer-normalised across its second axis.

    Args:
        query: The raw query string.

    Returns:
        The layer-normalised tensor returned for the single-item batch.
    """
    prefixed_query = "search_query: " + query
    encoded = model.encode([prefixed_query], convert_to_tensor=True)
    feature_count = encoded.shape[1]
    return F.layer_norm(encoded, normalized_shape=(feature_count,))
def get_normEmbed(query_embedding, loaded_embed, matryoshka_dim):
    """Truncate both embedding matrices and L2-normalise their rows.

    Args:
        query_embedding: 2-D tensor of query embeddings.
        loaded_embed: 2-D tensor of document embeddings.
        matryoshka_dim: Number of leading columns to keep from each matrix.

    Returns:
        A ``(query, documents)`` tuple of the truncated matrices, each row
        scaled to unit L2 norm.
    """

    def _head_unit_rows(matrix):
        # Keep the first `matryoshka_dim` columns, then normalise per row.
        return F.normalize(matrix[:, :matryoshka_dim], p=2, dim=1)

    return _head_unit_rows(query_embedding), _head_unit_rows(loaded_embed)
def insert_line_breaks(text, interval=30):
    """Wrap ``text`` for HTML hover labels by inserting ``<br>`` separators.

    Words (split on single spaces) are accumulated onto a line; once the
    running length of the line, counting one separator per word, reaches
    ``interval``, the line is closed. Lines are then joined with ``<br>``.

    Unlike a naive "append ``<br>`` after the word" approach, the break is
    only emitted *between* lines, so the result never ends in a dangling
    ``<br>`` and carries no stray space in front of a break.

    Args:
        text: The string to wrap.
        interval: Minimum accumulated line length that triggers a break.

    Returns:
        The wrapped string with surrounding whitespace stripped.
    """
    lines = []
    current_words = []
    line_length = 0
    for word in text.split(' '):
        current_words.append(word)
        # +1 accounts for the space that follows each word.
        line_length += len(word) + 1
        if line_length >= interval:
            lines.append(' '.join(current_words))
            current_words = []
            line_length = 0
    # Flush the last, not-yet-full line (if any).
    if current_words:
        lines.append(' '.join(current_words))
    return '<br>'.join(lines).strip()
# ---------------------------------------------------------------------------
# App body: load resources, collect a query, rank products by similarity and
# plot a 2-D t-SNE map of the query against every product embedding.
# ---------------------------------------------------------------------------
model = get_nomicModel()
# NOTE: despite its name, this holds the "Description" text column of the
# product table, not embeddings.
bigDollEmbedding = get_df()["Description"]
# Precomputed document embeddings; torch.Tensor(...) builds a CPU tensor.
docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))

toggle = st.toggle('sample queries')

with st.form("my_form"):
    if toggle:
        query_input = st.selectbox(
            'select a query:',
            (
                'Pack of two assorted boxers, has two pockets, an elasticated waistbandDisclaimer: The final product delivered might vary in colour and prints from the display here.',
                'Beige self design shoulder bag, has a zip closure1 main compartment, 3 inner pocketsTwo Handles',
                'Set Content: 1 photo frameColour: Black and whiteFrame Pattern: SolidShape: SquareMaterial: Acrylic',
                'A pair of dark grey solid boxers, has a slip-on closure with an elasticated waistband and drawstring, two pocket',
                'Red & Black solid sweatshirt, has a hood, two pockets, long sleeves, zip closure, straight hem',
            ),
        )
    else:
        query_input = st.text_input("")
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)

    submitted = st.form_submit_button("Submit")
    if submitted:
        # Keep the query on the CPU next to docEmbedding so matmul and the
        # NumPy conversions below also work if the model runs on another
        # device (assumption: encode() returns a tensor on the model device).
        queryEmbedding = get_searchQueryEmbedding(query_input).cpu()
        query_embedNorm, loaded_embedNorm = get_normEmbed(
            queryEmbedding, docEmbedding, Matry_dim
        )

        # Rows are L2-normalised by get_normEmbed, so this dot product is the
        # cosine similarity of the query against every document.
        similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)

        # Materialise the descriptions once instead of once per top hit.
        descriptions = bigDollEmbedding.tolist()

        # topk raises if k exceeds the number of documents; clamp it.
        top_k = min(10, similarity_scores.shape[1])
        top_values, top_indices = torch.topk(similarity_scores, top_k, dim=1)
        to_index = top_indices[0].tolist()
        top_items_per_query = [descriptions[index] for index in to_index]

        # Plain Python floats keep the Score column numeric in pandas.
        df = pd.DataFrame(
            {"Product": top_items_per_query, "Score": top_values[0].tolist()}
        )
        df["Product"] = df["Product"].str.replace(
            "search_document:", "", regex=False
        )

        # Project the query (row 0) together with all documents to 2-D.
        allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
        tsne = TSNE(n_components=2, random_state=0)
        projections = tsne.fit_transform(allEmbedd.numpy())

        # Automatically wrap the hover text
        listHover = [insert_line_breaks(hover_text, 30) for hover_text in descriptions]
        fig = px.scatter(
            projections, x=0, y=1,
            hover_name=[query_input] + listHover,
            # One label per plotted point; derived from the data instead of
            # a hard-coded document count.
            color=["search_query"] + ["search_document"] * len(listHover),
        )

        col1, col2 = st.columns([2, 2])
        col2.plotly_chart(fig, use_container_width=True)
        col1.dataframe(df)

st.caption("Dataset Credit : kaggle")
|