ola13 ToluClassics commited on
Commit
f077846
1 Parent(s): 943373e

Migrate from Gradio to Streamlit (#1)

Browse files

- Migrate from Gradio to Streamlit (f36e63226f3c0075115f552d198d72485b405c57)


Co-authored-by: Odunayo Ogundepo <ToluClassics@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +207 -208
app.py CHANGED
@@ -6,107 +6,100 @@ import pprint
6
  import re
7
  import string
8
 
9
- import gradio as gr
 
10
  import requests
11
 
12
- pp = pprint.PrettyPrinter(indent=2)
13
-
14
 
15
- def get_docid_html(docid):
16
- data_org, dataset, docid = docid.split("/")
17
 
18
- docid_html = """<a
19
- class="underline-on-hover"
20
- title="I am hovering over the text"
21
- style="color:#2D31FA;"
22
- href="https://huggingface.co/datasets/bigscience-data/{}"
23
- target="_blank">{}</a><span style="color: #7978FF;">/{}</span>""".format(
24
- dataset, data_org + "/" + dataset, docid
25
  )
26
- return docid_html
27
-
28
-
29
- PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
30
- PII_PREFIX = "PI:"
31
-
32
-
33
- def process_pii(text):
34
- for tag in PII_TAGS:
35
- text = text.replace(
36
- PII_PREFIX + tag,
37
- """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>""".format(tag),
38
- )
39
- return text
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- def process_meta_roots(result):
43
- meta_html = (
44
- """
45
- <p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
46
- <a href='{}' target='_blank'>{}</a></p>""".format(
47
- result["meta"]["url"], result["meta"]["url"]
48
- )
49
- if "meta" in result and result["meta"] is not None and "url" in result["meta"]
50
- else ""
51
- )
52
 
53
 
 
54
  """
55
- 'meta': { 'docs': [ { 'TEXT': 'Hello World Example Hello '
56
- 'World Page Hello World.',
57
- 'URL': 'http://images.slideplayer.com/8/2335183/slides/slide_6.jpg',
58
- '_id': 592573973},
59
- { 'TEXT': 'Hello World Example Hello '
60
- 'World Page Hello World.',
61
- 'URL': 'http://images.slideplayer.com/8/2335183/slides/slide_9.jpg',
62
- '_id': 1807595732},
63
- { 'TEXT': 'Hello World Example Hello '
64
- 'World Page Hello World.',
65
- 'URL': 'http://images.slideplayer.com/8/2335183/slides/slide_10.jpg',
66
- '_id': 1864921031},
67
- { 'TEXT': 'Hello World Example Hello '
68
- 'World Page Hello World!',
69
- 'URL': 'http://images.slideplayer.com/8/2335183/slides/slide_5.jpg',
70
- '_id': 1964462104},
71
- { 'TEXT': 'Hello World Example Hello '
72
- 'World Page Hello World.',
73
- 'URL': 'http://images.slideplayer.com/8/2335183/slides/slide_8.jpg',
74
- '_id': 2167992166}]},
75
  """
76
- def process_meta_laion(result):
77
- meta_html = """"""
78
- if "meta" not in result:
79
- return meta_html
80
- for doc in result["meta"]["docs"]:
81
- # doc = json.loads(doc)
82
- print(type(doc), doc)
83
- print(doc["URL"])
84
- meta_html += """<p class='underline-on-hover' style='font-size:12px; color:#7978FF; text-align: left;'>
85
- <a href='{}' target='_blank'>{}</a></p>""".format(doc["URL"], doc["URL"])
86
-
87
- return meta_html
88
-
89
- def process_results(results, highlight_terms):
90
- if len(results) == 0:
91
- return """<br><p style='font-family: Arial; color:Silver; text-align: center;'>
92
- No results retrieved.</p><br><hr>"""
93
-
94
- results_html = ""
95
- for result in results:
96
- tokens = result["text"].split()
97
- tokens_html = []
98
- for token in tokens:
99
- if token in highlight_terms:
100
- tokens_html.append("<b>{}</b>".format(token))
101
- else:
102
- tokens_html.append(token)
103
- tokens_html = " ".join(tokens_html)
104
- tokens_html = process_pii(tokens_html)
105
- meta_html = process_meta_laion(result)
106
- docid_html = """<p style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</p>""".format(result["docid"]) # get_docid_html(result["docid"])
107
- language_html = """<p style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</p>""".format(result["lang"])
108
- results_html += tokens_html + meta_html + "<br>"
109
- return results_html + "<hr>"
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  def scisearch(query, language, num_results=10):
@@ -128,8 +121,6 @@ def scisearch(query, language, num_results=10):
128
 
129
  payload = json.loads(output.text)
130
 
131
- pp.pprint(payload)
132
-
133
  if "err" in payload:
134
  if payload["err"]["type"] == "unsupported_lang":
135
  detected_lang = payload["err"]["meta"]["detected_lang"]
@@ -141,38 +132,6 @@ def scisearch(query, language, num_results=10):
141
 
142
  results = payload["results"]
143
  highlight_terms = payload["highlight_terms"]
144
-
145
- if language == "detect_language":
146
- return (
147
- (
148
- f"""<p style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
149
- Detected language: <b>{results[0]["lang"]}</b></p><br><hr><br>"""
150
- if len(results) > 0 and language == "detect_language"
151
- else ""
152
- )
153
- + process_results(results, highlight_terms)
154
- )
155
-
156
- if language == "all":
157
- results_html = ""
158
- for lang, results_for_lang in results.items():
159
- if len(results_for_lang) == 0:
160
- results_html += f"""<p style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
161
- No results for language: <b>{lang}</b><hr></p>"""
162
- continue
163
-
164
- collapsible_results = f"""
165
- <details>
166
- <summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
167
- Results for language: <b>{lang}</b><hr>
168
- </summary>
169
- {process_results(results_for_lang, highlight_terms)}
170
- </details>"""
171
- results_html += collapsible_results
172
- return results_html
173
-
174
- return process_results(results, highlight_terms)
175
-
176
  except Exception as e:
177
  results_html = f"""
178
  <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
@@ -183,92 +142,132 @@ def scisearch(query, language, num_results=10):
183
  """
184
  print(e)
185
 
186
- return results_html
187
 
 
 
188
 
189
- def flag(query, language, num_results, issue_description):
190
- try:
191
- post_data = {"query": query, "k": num_results, "flag": True, "description": issue_description}
192
- if language != "detect_language":
193
- post_data["lang"] = language
194
-
195
- output = requests.post(
196
- os.environ.get("address"),
197
- headers={"Content-type": "application/json"},
198
- data=json.dumps(post_data),
199
- timeout=120,
200
- )
201
-
202
- results = json.loads(output.text)
203
- except:
204
- print("Error flagging")
205
- return ""
206
-
207
-
208
- description = """# <p style="text-align: center;">GAIA 🌖🌏</p>
209
- A large scale text corpora search engine."""
210
-
211
-
212
- if __name__ == "__main__":
213
- demo = gr.Blocks(
214
- css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
215
- )
216
-
217
- with demo:
218
- with gr.Row():
219
- gr.Markdown(value=description)
220
- with gr.Row():
221
- query = gr.Textbox(lines=2, placeholder="Type your query here...", label="Query")
222
- with gr.Row():
223
- lang = gr.Dropdown(
224
- choices=[
225
- "ar",
226
- "ca",
227
- "code",
228
- "en",
229
- "es",
230
- "eu",
231
- "fr",
232
- "id",
233
- "indic",
234
- "nigercongo",
235
- "pt",
236
- "vi",
237
- "zh",
238
- "detect_language",
239
- "all",
240
- ],
241
- value="en",
242
- label="Language",
243
- )
244
- with gr.Row():
245
- k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
246
- with gr.Row():
247
- submit_btn = gr.Button("Submit")
248
- with gr.Row():
249
- results = gr.HTML(label="Results")
250
- flag_description = """
251
- <p class='flagging'>
252
- If you choose to flag your search, we will save the query, language and the number of results you requested.
253
- Please consider adding any additional context in the box on the right.</p>"""
254
- with gr.Column(visible=False) as flagging_form:
255
- flag_txt = gr.Textbox(
256
- lines=1,
257
- placeholder="Type here...",
258
- label="""If you choose to flag your search, we will save the query, language and the number of results
259
- you requested. Please consider adding relevant additional context below:""",
260
  )
261
- flag_btn = gr.Button("Flag Results")
262
- flag_btn.click(flag, inputs=[query, lang, k, flag_txt], outputs=[flag_txt])
263
 
264
- def submit(query, lang, k):
265
- if query == "":
266
- return ["", ""]
267
- return {
268
- results: scisearch(query, lang, k),
269
- flagging_form: gr.update(visible=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  }
271
-
272
- submit_btn.click(submit, inputs=[query, lang, k], outputs=[results, flagging_form])
273
-
274
- demo.launch(enable_queue=True, debug=True)
 
 
 
 
 
 
6
  import re
7
  import string
8
 
9
+ import streamlit as st
10
+ import streamlit.components.v1 as components
11
  import requests
12
 
 
 
13
 
14
+ pp = pprint.PrettyPrinter(indent=2)
15
+ st.set_page_config(page_title="Gaia Search", layout="wide")
16
 
17
+ os.makedirs(os.path.join(os.getcwd(),".streamlit"), exist_ok = True)
18
+ with open(os.path.join(os.getcwd(),".streamlit/config.toml"), "w") as file:
19
+ file.write(
20
+ '[theme]\nbase="light"'
 
 
 
21
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ LANG_MAPPING = {'Arabic':'ar',
24
+ 'Catalan':'ca',
25
+ 'Code':'code',
26
+ 'English':'en',
27
+ 'Spanish':'es',
28
+ 'French':'fr',
29
+ 'Indonesian':'id',
30
+ 'Indic':'indic',
31
+ 'Niger-Congo':'nigercongo',
32
+ 'Portuguese': 'pt',
33
+ 'Vietnamese': 'vi',
34
+ 'Chinese': 'zh',
35
+ 'Detect Language':'detect_language',
36
+ 'All':'all'}
37
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
+ st.sidebar.markdown(
41
  """
42
+ <style>
43
+ .aligncenter {
44
+ text-align: center;
45
+ font-weight: bold;
46
+ font-size: 50px;
47
+ }
48
+ </style>
49
+ <p class="aligncenter">Gaia Search 🌖🌏</p>
50
+ <p style="text-align: center;"> A search engine for the LAION large scale image caption corpora</p>
51
+ """,
52
+ unsafe_allow_html=True,
53
+ )
54
+
55
+ st.sidebar.markdown(
 
 
 
 
 
 
56
  """
57
+ <style>
58
+ .aligncenter {
59
+ text-align: center;
60
+ }
61
+ </style>
62
+ <p style='text-align: center'>
63
+ <a href="" >GitHub</a> | <a href="" >Project Report</a>
64
+ </p>
65
+ <p class="aligncenter">
66
+ <a href="" target="_blank">
67
+ <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
68
+ </a>
69
+ </p>
70
+ """,
71
+ unsafe_allow_html=True,
72
+ )
73
+
74
+ query = st.sidebar.text_input(label='Search query', value='')
75
+ language = st.sidebar.selectbox(
76
+ 'Language',
77
+ ('Arabic', 'Catalan', 'Code', 'English', 'Spanish', 'French', 'Indonesian', 'Indic', 'Niger-Congo', 'Portuguese', 'Vietnamese', 'Chinese', 'Detect Language', 'All'),
78
+ index=3)
79
+ max_results = st.sidebar.slider(
80
+ "Maximum Number of Results",
81
+ min_value=1,
82
+ max_value=100,
83
+ step=1,
84
+ value=10,
85
+ help="Maximum Number of Documents to return",
86
+ )
87
+ footer="""<style>
88
+ .footer {
89
+ position: fixed;
90
+ left: 0;
91
+ bottom: 0;
92
+ width: 100%;
93
+ background-color: white;
94
+ color: black;
95
+ text-align: center;
96
+ }
97
+ </style>
98
+ <div class="footer">
99
+ <p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
100
+ </div>
101
+ """
102
+ st.sidebar.markdown(footer,unsafe_allow_html=True)
103
 
104
 
105
  def scisearch(query, language, num_results=10):
 
121
 
122
  payload = json.loads(output.text)
123
 
 
 
124
  if "err" in payload:
125
  if payload["err"]["type"] == "unsupported_lang":
126
  detected_lang = payload["err"]["meta"]["detected_lang"]
 
132
 
133
  results = payload["results"]
134
  highlight_terms = payload["highlight_terms"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  except Exception as e:
136
  results_html = f"""
137
  <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
 
142
  """
143
  print(e)
144
 
145
+ return results, highlight_terms
146
 
147
+ PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
148
+ PII_PREFIX = "PI:"
149
 
150
+ def process_pii(text):
151
+ for tag in PII_TAGS:
152
+ text = text.replace(
153
+ PII_PREFIX + tag,
154
+ """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>""".format(tag),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
+ return text
 
157
 
158
+ def highlight_string(paragraph: str, highlight_terms: list) -> str:
159
+ for term in highlight_terms:
160
+ paragraph = re.sub(f"\\b{term}\\b", f"<b>{term}</b>", paragraph, flags=re.I)
161
+ paragraph = process_pii(paragraph)
162
+ return paragraph
163
+
164
+ def process_results(hits: list, highlight_terms: list) -> str:
165
+ hit_list = []
166
+ for i, hit in enumerate(hits):
167
+ res_head = f"""
168
+ <div class="searchresult">
169
+ <h2>{i+1}. Document ID: {hit['docid']}</h2>
170
+ <p>Language: <string>{hit['lang']}</string>, Score: {round(hit['score'], 2)}</p>
171
+ """
172
+ for subhit in hit['meta']['docs']:
173
+ res_head += f"""
174
+ <button onclick="load_image({subhit['_id']})">Load Image</button><br>
175
+ <p><img id='{subhit['_id']}' src='{subhit['URL']}' style="width:400px;height:auto;display:none;"></p>
176
+ <a href='{subhit['URL']}'>{subhit['URL']}</a>
177
+ <p>{highlight_string(subhit['TEXT'], highlight_terms)}</p>
178
+ """
179
+ res_head += f"""
180
+ <p>{highlight_string(hit['text'], highlight_terms)}</p>
181
+ </div>
182
+ <hr>
183
+ """
184
+ hit_list.append(res_head)
185
+ return " ".join(hit_list)
186
+
187
+
188
+ if st.sidebar.button("Search"):
189
+ hits, highlight_terms = scisearch(query, LANG_MAPPING[language], max_results)
190
+ html_results = process_results(hits, highlight_terms)
191
+ rendered_results = f"""
192
+ <div id="searchresultsarea">
193
+ <br>
194
+ <p id="searchresultsnumber">About {max_results} results</p>
195
+ {html_results}
196
+ </div>
197
+ """
198
+ st.markdown("""
199
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css" rel="stylesheet"
200
+ integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
201
+ """,
202
+ unsafe_allow_html=True)
203
+ st.markdown(
204
+ """
205
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
206
+ """,
207
+ unsafe_allow_html=True)
208
+ st.markdown(
209
+ f"""
210
+ <div class="row no-gutters mt-3 align-items-center">
211
+ Gaia Search 🌖🌏
212
+ <div class="col col-md-4">
213
+ <input class="form-control border-secondary rounded-pill pr-5" type="search" value="{query}" id="example-search-input2">
214
+ </div>
215
+ <div class="col-auto">
216
+ <button class="btn btn-outline-light text-dark border-0 rounded-pill ml-n5" type="button">
217
+ <i class="fa fa-search"></i>
218
+ </button>
219
+ </div>
220
+ </div>
221
+ """,
222
+ unsafe_allow_html=True)
223
+ components.html(
224
+ """
225
+ <style>
226
+ #searchresultsarea {
227
+ font-family: 'Arial';
228
+ }
229
+
230
+ #searchresultsnumber {
231
+ font-size: 0.8rem;
232
+ color: gray;
233
+ }
234
+
235
+ .searchresult h2 {
236
+ font-size: 19px;
237
+ line-height: 18px;
238
+ font-weight: normal;
239
+ color: rgb(7, 111, 222);
240
+ margin-bottom: 0px;
241
+ margin-top: 25px;
242
+ }
243
+
244
+ .searchresult a {
245
+ font-size: 12px;
246
+ line-height: 12px;
247
+ color: green;
248
+ margin-bottom: 0px;
249
+ }
250
+
251
+ .dark-mode {
252
+ color: white;
253
+ }
254
+ </style>
255
+ <script>
256
+ function load_image(id){
257
+ console.log(id)
258
+ var x = document.getElementById(id);
259
+ console.log(x)
260
+ if (x.style.display === "none") {
261
+ x.style.display = "block";
262
+ } else {
263
+ x.style.display = "none";
264
  }
265
+ };
266
+ function myFunction() {
267
+ var element = document.body;
268
+ element.classList.toggle("dark-mode");
269
+ }
270
+ </script>
271
+ <button onclick="myFunction()">Toggle dark mode</button>
272
+ """ + rendered_results, height=800, scrolling=True
273
+ )