Spaces:
Running
Running
from flask import Flask, request | |
import json | |
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments | |
from pprint import pprint | |
from hf_search import hf_search | |
# Flask application object for this module; started by the __main__ guard.
app = Flask(import_name=__name__)
def hello() -> str:
    """Return a fixed HTML ``<h1>`` greeting rendered in blue.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered.
    """
    greeting = "<h1 style='color:blue'>Hello There!</h1>"
    return greeting
def hf_api() -> str:
    """Search the Hugging Face Hub with ``HfApi.list_models``.

    Reads a JSON request body with keys:
      - ``query``: free-text search string, passed as ``search=``.
      - ``filters``: JSON-encoded string (or an already-decoded object, or
        absent) whose optional ``task`` / ``library`` entries build a
        ``ModelFilter``.
      - ``limit``: maximum number of hits to return (default 5).

    Returns:
        A JSON string ``{"value": [...], "count": N}``. Each hit carries
        ``modelId``, ``tags``, ``downloads`` and ``likes``. ``count`` is the
        number of models received before truncating to ``limit``.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered.
    """
    request_data = request.get_json()
    query = request_data.get("query")

    raw_filters = request_data.get("filters")
    # The original contract is a JSON-encoded string. Also tolerate a missing
    # or empty field (json.loads(None) raised TypeError) and an object the
    # client already sent decoded.
    if raw_filters and isinstance(raw_filters, (str, bytes, bytearray)):
        raw_filters = json.loads(raw_filters)
    filters = raw_filters or {}

    # Coerce so the slice below and list_models get an int even if the
    # client sent the number as a string (e.g. "5").
    limit = int(request_data.get("limit", 5))
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    api = HfApi()
    # .get(): either key may be absent. ModelFilter's task/library
    # parameters are optional, so None means "not constrained".
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, full=True)

    # getattr works whether or not the attribute lives in the instance
    # __dict__. Newer huggingface_hub releases expose the repo id as `id`,
    # so fall back to it when `modelId` is missing.
    hits = [
        {
            "modelId": getattr(model, "modelId", None) or getattr(model, "id", None),
            "tags": getattr(model, "tags", None),
            "downloads": getattr(model, "downloads", None),
            "likes": getattr(model, "likes", None),
        }
        for model in models
    ]
    count = len(hits)
    # Slicing is a no-op when there are already <= limit hits.
    hits = hits[:limit]
    pprint(hits)
    return json.dumps({"value": hits, "count": count})
def semantic_search() -> str:
    """Run a "retrieve & rerank" search through ``hf_search``.

    Reads a JSON request body with ``query``, ``filters`` (JSON-encoded
    string, already-decoded object, or absent) and ``limit`` (default 5).
    All three are forwarded to ``hf_search``.

    Returns:
        A JSON string ``{"value": [...], "count": N}``. Each hit carries
        ``modelId``, ``tags``, ``downloads``, ``likes`` and ``readme``
        (None when the hit has no readme). ``count`` is the number of hits.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered.
    """
    request_data = request.get_json()
    query = request_data.get("query")

    raw_filters = request_data.get("filters")
    # A missing or empty "filters" field used to crash in json.loads(None).
    # NOTE(review): confirm hf_search treats an empty dict as "no filtering".
    if raw_filters and isinstance(raw_filters, (str, bytes, bytearray)):
        raw_filters = json.loads(raw_filters)
    filters = raw_filters or {}

    limit = request_data.get("limit", 5)
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    hits = hf_search(query=query, method="retrieve & rerank", limit=limit, filters=filters)
    # Project each hit down to the fields the client consumes.
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in hits
    ]
    return json.dumps({"value": hits, "count": len(hits)})
def bm25_search() -> str:
    """Run a BM25 keyword search through ``hf_search``.

    Reads a JSON request body with ``query``, ``filters`` (JSON-encoded
    string, already-decoded object, or absent) and ``limit`` (default 5).
    ``filters`` is parsed and printed but not yet applied (see TODO).

    Returns:
        A JSON string ``{"value": [...], "count": N}``. Each hit carries
        ``modelId``, ``tags``, ``downloads``, ``likes`` and ``readme``
        (None when absent). Hits sharing a ``modelId`` are collapsed to the
        first occurrence, and ``count`` is the number of unique hits.

    NOTE(review): deduplication runs after hf_search has applied ``limit``,
    so fewer than ``limit`` hits may be returned. No ``@app.route``
    decorator is visible for this view; confirm how it is registered.
    """
    request_data = request.get_json()
    query = request_data.get("query")

    raw_filters = request_data.get("filters")
    # Filters are unused below, so a missing or empty field must not fail the
    # request. Previously json.loads(None) raised TypeError.
    if raw_filters and isinstance(raw_filters, (str, bytes, bytearray)):
        raw_filters = json.loads(raw_filters)
    filters = raw_filters or {}

    limit = request_data.get("limit", 5)
    print("query", query)
    print("filters", filters)
    print("limit", limit)

    # TODO: filters
    raw_hits = hf_search(query=query, method="bm25", limit=limit)

    # Keep only the first hit per modelId, in original order. A dict keyed by
    # id makes this O(n); the previous version rebuilt the list of earlier
    # ids for every element, which was O(n^2).
    unique_hits = {}
    for hit in raw_hits:
        model_id = hit["modelId"]
        if model_id in unique_hits:
            continue
        unique_hits[model_id] = {
            "modelId": model_id,
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
    hits = list(unique_hits.values())
    return json.dumps({"value": hits, "count": len(hits)})
if __name__ == "__main__":
    # Serve the app on localhost:5000 when this module is executed directly.
    app.run(port=5000, host="localhost")