Spaces:
Running
Running
import json | |
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments | |
from pprint import pprint | |
from hf_search import HFSearch | |
import streamlit as st | |
import itertools | |
from pbr.version import VersionInfo | |
# Report which hf_search release is installed; handy when reading the app logs.
_hf_search_version = VersionInfo("hf_search").version_string()
print("hf_search version:", _hf_search_version)

# Single module-level search engine shared by the search functions in this
# file; it is constructed with top_k=200.
hf_search = HFSearch(top_k=200)
def hf_api(query, limit=5, sort=None, filters=None):
    """Search models through the Hugging Face Hub API.

    Parameters
    ----------
    query : str
        Search string forwarded to ``HfApi.list_models``.
    limit : int
        Maximum number of hits requested from the API and returned.
    sort : str or None
        Sort key forwarded to ``HfApi.list_models``.
    filters : dict or None
        Optional mapping whose ``"task"`` and ``"library"`` entries are used
        to build the ``ModelFilter``. Missing entries are passed as ``None``.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` where each hit carries the
        ``modelId``, ``tags``, ``downloads`` and ``likes`` of a model and
        ``count`` is the number of models collected before truncation to
        ``limit``.
    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default.
    filters = {} if filters is None else filters

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    api = HfApi()
    # ``.get`` keeps the call usable when a filter key is absent (including
    # the default, empty filters) instead of raising KeyError.
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, sort=sort, full=True)

    hits = []
    for model in models:
        info = model.__dict__
        hits.append(
            {
                "modelId": info.get("modelId"),
                "tags": info.get("tags"),
                "downloads": info.get("downloads"),
                "likes": info.get("likes"),
            }
        )

    # ``count`` reflects everything the API returned; only the hit list is
    # truncated to ``limit``.
    count = len(hits)
    return {"hits": hits[:limit], "count": count}
def semantic_search(query, limit=5, sort=None, filters=None):
    """Run a "retrieve & rerank" search through the shared ``hf_search`` engine.

    Parameters
    ----------
    query : str
        Search string.
    limit : int
        Forwarded to ``hf_search.search`` as ``limit``.
    sort : str or None
        Forwarded to ``hf_search.search`` as ``sort``.
    filters : dict or None
        Forwarded to ``hf_search.search`` as ``filters``; ``None`` is
        replaced by a fresh empty dict.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` where each hit carries ``modelId``,
        ``tags``, ``downloads``, ``likes`` and ``readme`` (``None`` when the
        raw hit has no readme) and ``count`` is the number of hits.
    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default, so the
    # external search call can never touch a dict shared between calls.
    filters = {} if filters is None else filters

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    raw_hits = hf_search.search(
        query=query,
        method="retrieve & rerank",
        limit=limit,
        sort=sort,
        filters=filters,
    )

    # Keep only the fields needed downstream.
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in raw_hits
    ]
    return {"hits": hits, "count": len(hits)}
def bm25_search(query, limit=5, sort=None, filters=None):
    """Run a BM25 search through the shared ``hf_search`` engine.

    Parameters
    ----------
    query : str
        Search string.
    limit : int
        Forwarded to ``hf_search.search`` as ``limit``.
    sort : str or None
        Forwarded to ``hf_search.search`` as ``sort``.
    filters : dict or None
        Forwarded to ``hf_search.search`` as ``filters``; ``None`` is
        replaced by a fresh empty dict.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` where the hits are de-duplicated by
        ``modelId`` (first occurrence wins, original order kept), each
        carrying ``modelId``, ``tags``, ``downloads``, ``likes`` and
        ``readme`` (``None`` when the raw hit has no readme). ``count`` is
        the number of unique hits.
    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default.
    filters = {} if filters is None else filters

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    # TODO: filters
    raw_hits = hf_search.search(query=query, method="bm25", limit=limit, sort=sort, filters=filters)

    # One pass with a seen-set replaces the quadratic "scan every earlier
    # hit" check. Assumes modelId values are hashable (they look like
    # strings) -- TODO confirm.
    seen_ids = set()
    hits = []
    for hit in raw_hits:
        model_id = hit["modelId"]
        if model_id in seen_ids:
            continue
        seen_ids.add(model_id)
        hits.append(
            {
                "modelId": model_id,
                "tags": hit["tags"],
                "downloads": hit["downloads"],
                "likes": hit["likes"],
                "readme": hit.get("readme"),
            }
        )
    return {"hits": hits, "count": len(hits)}
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    # https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
    """Let the user paginate a set of articles.

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterable[Any]
        The articles to display in the paginator.
    articles_per_page : int
        The number of articles to display per page.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*, each
        paired with its index in the full article list.
    """
    # Figure out where to display the paginator.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    # Materialize the articles so they can be counted.
    articles = list(articles)
    n_pages = (len(articles) - 1) // articles_per_page + 1

    def page_format_func(page):
        # Label each page with the index range it covers; the range is derived
        # from articles_per_page rather than a hard-coded page size of 10.
        first = page * articles_per_page
        return f"Results {first} to {first + articles_per_page - 1}"

    # Display a pagination selectbox in the specified location.
    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)
    if page_number is None:
        # With no articles there are no pages to choose from and the
        # selectbox has nothing selected; fall back to the (empty) first page.
        page_number = 0

    # Iterate over the articles in the page to let the user display them.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)