PrabakaranC commited on
Commit
d845ab5
·
verified ·
1 Parent(s): 50d230c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -2
app.py CHANGED
@@ -1,8 +1,111 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
 
 
 
 
 
 
 
 
 
4
 
5
- st.header("Explore the Russian Dolls - Nomic Embed 1.5",divider='violet')
 
6
  st.write("matryoshka representation learning")
7
 
8
- query_input = st.text_input("query your product")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
4
+ import torch.nn.functional as F
5
+ from sentence_transformers import SentenceTransformer
6
+ import numpy as np
7
+ from sklearn.manifold import TSNE
8
+ import plotly.express as px
9
+ import torch
10
+ import plotly.io as pio
11
+ pio.templates.default = "plotly"
12
+
13
 
14
+ st. set_page_config(layout="wide")
15
+ st.header("Explore the Russian Dolls :nesting_dolls: - _ :green[Nomic Embed 1.5]_",divider='violet')
16
  st.write("matryoshka representation learning")
17
 
18
+
19
+ @st.cache_data
20
+ def get_df():
21
+ prodDf = pd.read_csv("./sample_products.csv")
22
+ return prodDf
23
+
24
+ @st.cache_resource
25
+ def get_nomicModel():
26
+ model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
27
+ return model
28
+
29
+ def get_searchQueryEmbedding(query):
30
+ embeddings = model.encode(["search_query: "+query], convert_to_tensor=True)
31
+ embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
32
+ return embeddings
33
+
34
+ def get_normEmbed(query_embedding,loaded_embed,matryoshka_dim):
35
+ query_embedNorm = query_embedding[:, :matryoshka_dim]
36
+ query_embedNorm = F.normalize(query_embedNorm, p=2, dim=1)
37
+ loaded_embedNorm = loaded_embed[:, :matryoshka_dim]
38
+ loaded_embedNorm = F.normalize(loaded_embedNorm, p=2, dim=1)
39
+ return query_embedNorm,loaded_embedNorm
40
+
41
+ def insert_line_breaks(text, interval=30):
42
+ words = text.split(' ')
43
+ wrapped_text = ''
44
+ line_length = 0
45
+ for word in words:
46
+ wrapped_text += word + ' '
47
+ line_length += len(word) + 1
48
+ if line_length >= interval:
49
+ wrapped_text += '<br>'
50
+ line_length = 0
51
+ return wrapped_text.strip()
52
+
53
+ # Automatically wrap the hover text
54
+
55
+
56
+ model = get_nomicModel()
57
+ bigDollEmbedding = get_df()["Description"]
58
+ docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))
59
+
60
+
61
+
62
+
63
+
64
+ with st.form("my_form"):
65
+ query_input = st.text_input("query your product")
66
+
67
+
68
+ sample_products = ["a","b","c"]
69
+ submitted = st.form_submit_button("Submit")
70
+
71
+ if submitted:
72
+ queryEmbedding = get_searchQueryEmbedding(query_input)
73
+ Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
74
+ query_embedNorm,loaded_embedNorm = get_normEmbed(queryEmbedding,docEmbedding,Matry_dim)
75
+
76
+ similarity_scores = torch.matmul(query_embedNorm,loaded_embedNorm.T)
77
+ top_values, top_indices = torch.topk(similarity_scores, 10, dim=1)
78
+ to_index = list(top_indices.numpy()[0])
79
+ top_items_per_query = [bigDollEmbedding.tolist()[index] for index in to_index]
80
+
81
+ df = pd.DataFrame({"Product":top_items_per_query,"Score":top_values[0]})
82
+ df["Product"] = df["Product"].str.replace("search_document:","")
83
+ # st.dataframe(df)
84
+
85
+ allEmbedd = torch.concat([query_embedNorm,loaded_embedNorm])
86
+
87
+ tsne = TSNE(n_components=2, random_state=0)
88
+
89
+ projections = tsne.fit_transform(allEmbedd)
90
+
91
+ listHover = bigDollEmbedding.tolist()
92
+ listHover =[insert_line_breaks(hover_text, 30) for hover_text in listHover]
93
+
94
+
95
+ fig = px.scatter(
96
+ projections, x=0, y=1,
97
+ hover_name=[query_input]+listHover,
98
+
99
+ color=["search_query"]+(["search_document"]*270)
100
+ )
101
+
102
+ col1, col2 = st.columns([2, 2])
103
+
104
+ col2.plotly_chart(fig, use_container_width=True)
105
+ col1.dataframe(df)
106
+
107
+
108
+
109
+
110
+
111
+