Spaces:
Running
Running
import os | |
import re | |
import json | |
import datetime | |
import itertools | |
import requests | |
from PIL import Image | |
import base64 | |
import streamlit as st | |
from huggingface_hub import ModelSearchArguments | |
import webbrowser | |
from numerize.numerize import numerize | |
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Let the user paginate a sequence of articles with a selectbox.

    Adapted from https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterable[Any]
        The articles to paginate. The iterable is materialized into a list.
    articles_per_page : int
        The number of articles to display per page. Must be at least 1.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*, each paired
        with its 0-based index in ``articles``. Empty when there are no articles.

    Raises
    ------
    ValueError
        If ``articles_per_page`` is less than 1.
    """
    if articles_per_page < 1:
        raise ValueError(f"articles_per_page must be >= 1, got {articles_per_page!r}")

    # Reserve a placeholder either in the sidebar or in the main page body.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    articles = list(articles)
    n_articles = len(articles)
    if n_articles == 0:
        # A selectbox with no options returns None, which would break the
        # index arithmetic below; there is nothing to page through anyway.
        return iter(())

    # Ceiling division: the number of pages needed to hold every article.
    n_pages = (n_articles - 1) // articles_per_page + 1

    def page_format_func(i):
        # Label each page by the 0-based index range it covers, using the real
        # page size and clamping to the last article so a partial final page
        # is not overstated.
        first = i * articles_per_page
        last = min(first + articles_per_page, n_articles) - 1
        return f"Results {first} to {last}"

    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Yield only the slice of (index, article) pairs belonging to the chosen page.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)
# Display label for each search backend offered in the sidebar.
_BACKEND_LABELS = {
    "hfapi": "Keyword search",
    "bm25": "BM25 search",
    "semantic": "Semantic Search",
}

# Seconds to wait for the search API before reporting a failure.
_SEARCH_TIMEOUT_S = 30


def _render_sidebar():
    """Draw the sidebar controls and return ``(search_backend, limit_results, library, task)``."""
    search_backend = st.sidebar.selectbox(
        "Search method",
        ["semantic", "bm25", "hfapi"],
        format_func=_BACKEND_LABELS.__getitem__,
    )
    limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)

    st.sidebar.markdown("# Filters")
    args = ModelSearchArguments()
    # The options offered are the mapping's values and the labels shown are the
    # matching keys; invert each mapping once instead of once per option.
    library_labels = {v: k for k, v in args.library.items()}
    task_labels = {v: k for k, v in args.pipeline_tag.items()}
    library = st.sidebar.multiselect(
        "Library", args.library.values(), format_func=library_labels.__getitem__
    )
    task = st.sidebar.multiselect(
        "Task", args.pipeline_tag.values(), format_func=task_labels.__getitem__
    )
    return search_backend, limit_results, library, task


def _search(search_backend, query, filters, limit):
    """POST *query* to ``<endpoint>/<search_backend>/search`` and return the decoded JSON.

    Returns None, after showing the error in the UI, when the request fails,
    the server answers with an error status, or the body is not valid JSON.
    """
    # Endpoint and API key may be overridden via the environment so the
    # credential need not live in source; the defaults are the values that
    # were previously hard-coded.
    endpoint = os.environ.get("HF_SEARCH_ENDPOINT", "http://localhost:5000")
    headers = {
        "Content-Type": "application/json",
        "api-key": os.environ.get("HF_SEARCH_API_KEY", "password"),
    }
    body = {
        "query": query,
        # Filters are sent as a JSON-encoded string nested inside the JSON body
        # (kept as-is; presumably what the backend expects).
        "filters": json.dumps(filters, default=str),
        "limit": limit,
    }
    try:
        response = requests.post(
            f"{endpoint}/{search_backend}/search",
            headers=headers,
            json=body,
            timeout=_SEARCH_TIMEOUT_S,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers JSON decoding failures on older `requests` versions.
        st.error(f"Search request failed: {err}")
        return None


def _render_hit(hit):
    """Render one search hit: metrics, a link to the model page, tags and README."""
    model_id = hit["modelId"]
    col1, col2, col3 = st.columns([5, 1, 1])
    col1.metric("Model", model_id)
    col2.metric("N° downloads", numerize(hit["downloads"]))
    col3.metric("N° likes", numerize(hit["likes"]))
    # A link is followed by the *user's* browser; webbrowser.open() would run
    # on the machine hosting the Streamlit server instead.
    st.markdown(f"[View model on 🤗](https://huggingface.co/{model_id})")
    st.write(f"**Tags:** {' • '.join(hit['tags'])}")
    if hit["readme"]:
        with st.expander("See README"):
            st.write(hit["readme"])
    # TODO: embed the model's Hugging Face Space below each hit
    # (e.g. with gradio's launchGradioFromSpaces via streamlit.components.v1.html).
    st.markdown("---")


def page():
    """Render the HF Search Engine page: sidebar filters, search bar and paginated results."""
    search_backend, limit_results, library, task = _render_sidebar()

    ### MAIN PAGE
    st.markdown(
        "<h1 style='text-align: center; '>🔎🤗 HF Search Engine</h1>",
        unsafe_allow_html=True,
    )
    search_query = st.text_input(
        "Search for a model in HuggingFace", value="", max_chars=None, key=None, type="default"
    )
    # Do not hit the API for an empty or whitespace-only query.
    if not search_query.strip():
        return

    response = _search(
        search_backend,
        search_query,
        {"library": library, "task": task},
        limit_results,
    )
    if response is None:
        return

    hit_list = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        # A missing or null "value" is treated as "no results".
        for hit in response.get("value") or []
    ]
    if not hit_list:
        st.write("No Search results, please try again with different keywords")
        return

    # Fall back to the number of returned hits if the API omits "count".
    total = response.get("count")
    if total is None:
        total = len(hit_list)
    st.write(f"Search results ({total}):")
    # The paginator pages through every returned hit, so that is how many are shown.
    for _, hit in paginator(
        f"Select results (showing {len(hit_list)} of {total} results)",
        hit_list,
    ):
        _render_hit(hit)