# hf-search/pages/search_engine.py
# Last commit: "quick fixes" (cd6abe6) by nouamanetazi (HF staff)
# File size: 5.48 kB
import os
import re
import json
import datetime
import itertools
import requests
from PIL import Image
import base64
import streamlit as st
from huggingface_hub import ModelSearchArguments
import webbrowser
from numerize.numerize import numerize
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Let the user paginate a set of articles.

    Adapted from
    https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterable[Any]
        The articles to display in the paginator.
    articles_per_page : int
        The number of articles to display per page.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*, each
        paired with its index in the full collection.
    """
    # Figure out where to display the paginator.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    # Materialize the input so it can be measured and sliced.
    articles = list(articles)
    n_articles = len(articles)
    n_pages = (n_articles - 1) // articles_per_page + 1

    def page_format_func(i):
        # Derive the range from the real page size (not a hard-coded 10) and
        # clamp the upper bound so the last page does not advertise results
        # that do not exist.
        first = i * articles_per_page
        last = min(first + articles_per_page, n_articles) - 1
        return f"Results {first} to {last}"

    # Display a pagination selectbox in the specified location.
    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)
    if page_number is None:
        # With no articles there are no pages to choose from and the widget
        # yields no selection; fall back to page 0 so the slice below is
        # simply empty instead of failing on `None * int`.
        page_number = 0

    # Iterate over the articles in the page to let the user display them.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)
def page():
    """Render the search page: sidebar options, search bar and result list.

    Reads the search method, result limit and library/task filters from the
    sidebar, posts the query to ``{endpoint}/{search_backend}/search`` and
    displays the returned hits page by page through ``paginator``. Nothing is
    requested while the search box is empty.
    """
    ### SIDEBAR
    backend_labels = {"hfapi": "Keyword search", "bm25": "BM25 search", "semantic": "Semantic Search"}
    search_backend = st.sidebar.selectbox(
        "Search method",
        ["semantic", "bm25", "hfapi"],
        format_func=lambda x: backend_labels[x],
    )
    limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)

    st.sidebar.markdown("# Filters")
    args = ModelSearchArguments()
    # Build the value -> display-name lookups once instead of on every
    # format_func call.
    library_names = {v: k for k, v in args.library.items()}
    task_names = {v: k for k, v in args.pipeline_tag.items()}
    library = st.sidebar.multiselect(
        "Library", args.library.values(), format_func=lambda x: library_names[x]
    )
    task = st.sidebar.multiselect(
        "Task", args.pipeline_tag.values(), format_func=lambda x: task_names[x]
    )

    ### MAIN PAGE
    st.markdown(
        "<h1 style='text-align: center; '>🔎🤗 HF Search Engine</h1>",
        unsafe_allow_html=True,
    )

    # Search bar
    search_query = st.text_input(
        "Search for a model in HuggingFace", value="", max_chars=None, key=None, type="default"
    )

    # Nothing to search for yet.
    if search_query == "":
        return

    # Search API
    # NOTE(review): endpoint and api-key are hard-coded; consider moving them
    # to configuration before deploying anywhere but a local setup.
    endpoint = "http://localhost:5000"
    headers = {
        "Content-Type": "application/json",
        "api-key": "password",
    }
    search_url = f"{endpoint}/{search_backend}/search"
    filters = {
        "library": library,
        "task": task,
    }
    search_body = {
        "query": search_query,
        "filters": json.dumps(filters, default=str),
        "limit": limit_results,
    }
    request_timeout = 30  # seconds; keeps the page from hanging on a dead backend

    try:
        http_response = requests.post(
            search_url, headers=headers, json=search_body, timeout=request_timeout
        )
        http_response.raise_for_status()
        response = http_response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers a body that is not valid JSON.
        st.error(f"Search request failed: {err}")
        return

    # Keep only the fields the page displays; "readme" is optional.
    hit_list = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in response.get("value") or []
    ]

    if not hit_list:
        st.write("No Search results, please try again with different keywords")
        return

    total_count = response.get("count")
    if total_count is None:
        # Backend did not report a total; fall back to what was returned.
        total_count = len(hit_list)
    st.write(f"Search results ({total_count}):")
    shown_results = min(total_count, 100)

    for _, hit in paginator(
        f"Select results (showing {shown_results} of {total_count} results)",
        hit_list,
    ):
        col1, col2, col3 = st.columns([5, 1, 1])
        col1.metric("Model", hit["modelId"])
        col2.metric("N° downloads", numerize(hit["downloads"]))
        col3.metric("N° likes", numerize(hit["likes"]))
        # `hit=hit` binds the current hit at definition time so every button
        # opens its own model rather than the last one in the loop.
        # NOTE(review): webbrowser.open runs inside the Python process, i.e.
        # on the host running the app — confirm this is the intended
        # behaviour for remote deployments.
        st.button(
            "View model on 🤗",
            on_click=lambda hit=hit: webbrowser.open(f"https://huggingface.co/{hit['modelId']}"),
            key=hit["modelId"],
        )
        st.write(f"**Tags:** {'&nbsp;&nbsp;•&nbsp;&nbsp;'.join(hit['tags'])}")
        if hit["readme"]:
            with st.expander("See README"):
                st.write(hit["readme"])
        # TODO: embed huggingface spaces (e.g. streamlit.components.v1.html
        # with Gradio's launchGradioFromSpaces bundle).
        st.markdown("---")