Federico Galatolo commited on
Commit
5426972
·
1 Parent(s): b0a20ef

added status indicator

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -8,10 +8,14 @@ from sklearn.manifold import TSNE
8
  import plotly.express as plx
9
 
10
  def compare():
11
- if len(multiselect) == 0: return
 
 
 
12
  target_field = f"{model}_features"
13
  ids = [documents[title] for title in multiselect]
14
 
 
15
  results = []
16
  for id in ids:
17
  results.append(es.search(
@@ -28,7 +32,7 @@ def compare():
28
  size=limit
29
  ))
30
 
31
-
32
  features = []
33
  classes = []
34
  sentences = []
@@ -39,6 +43,7 @@ def compare():
39
 
40
  features = np.concatenate(features)
41
 
 
42
  scaler = StandardScaler()
43
  features = scaler.fit_transform(features)
44
  tsne = TSNE(n_components=2, metric="cosine", init="pca")
@@ -53,7 +58,7 @@ def compare():
53
  sentences=sentences
54
  ))
55
 
56
-
57
  plot_placeholder.plotly_chart(plx.scatter(
58
  data_frame=df,
59
  x="x",
@@ -77,5 +82,6 @@ limit = st.sidebar.number_input("Sentences per document", 1000)
77
 
78
  plot_placeholder = st.empty()
79
 
 
80
  if st.sidebar.button("Compare"):
81
  compare()
 
8
  import plotly.express as plx
9
 
10
  def compare():
11
+ if len(multiselect) == 0:
12
+ plot_placeholder.error("Select at least one document")
13
+ return
14
+
15
  target_field = f"{model}_features"
16
  ids = [documents[title] for title in multiselect]
17
 
18
+ status_indicator.write("Retrieving embeddings...")
19
  results = []
20
  for id in ids:
21
  results.append(es.search(
 
32
  size=limit
33
  ))
34
 
35
+ status_indicator.write("Merging embeddings...")
36
  features = []
37
  classes = []
38
  sentences = []
 
43
 
44
  features = np.concatenate(features)
45
 
46
+ status_indicator.write("Computing TSNE...")
47
  scaler = StandardScaler()
48
  features = scaler.fit_transform(features)
49
  tsne = TSNE(n_components=2, metric="cosine", init="pca")
 
58
  sentences=sentences
59
  ))
60
 
61
+ status_indicator.write("All done...")
62
  plot_placeholder.plotly_chart(plx.scatter(
63
  data_frame=df,
64
  x="x",
 
82
 
83
  plot_placeholder = st.empty()
84
 
85
+ status_indicator = st.sidebar.empty()
86
  if st.sidebar.button("Compare"):
87
  compare()