hf-search / server /api.py
nouamanetazi's picture
nouamanetazi HF staff
quick fixes
cd6abe6
raw
history blame
2.93 kB
from flask import Flask, request
import json
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments
from pprint import pprint
from hf_search import hf_search
# Module-level Flask application; the route decorators below register on it.
app = Flask(import_name=__name__)
@app.route("/hello")
def hello():
    """Return a fixed HTML greeting for ``/hello`` (handy to check the server responds)."""
    greeting = "<h1 style='color:blue'>Hello There!</h1>"
    return greeting
@app.route("/hfapi/search", methods=["POST"])
def hf_api():
    """Search models through the Hugging Face Hub API (``HfApi.list_models``).

    JSON body fields:
        query: free-text search string, forwarded as ``search=``.
        filters: JSON-encoded string (an already-decoded object is also
            accepted); its optional ``task`` and ``library`` keys are turned
            into a ``ModelFilter``. A missing key means "no filter on it".
        limit: maximum number of hits to return (default 5; a numeric
            string is accepted).

    Returns:
        A JSON string ``{"value": hits, "count": count}``. Each hit holds
        ``modelId``, ``tags``, ``downloads`` and ``likes``; ``count`` is the
        number of models received from the Hub before trimming to ``limit``.
    """
    # A JSON ``null`` body decodes to None; treat it like an empty object.
    request_data = request.get_json() or {}
    query = request_data.get("query")

    # The client sends filters as a JSON-encoded string. Tolerate a missing,
    # blank or already-decoded value instead of failing inside json.loads.
    raw_filters = request_data.get("filters")
    if isinstance(raw_filters, str):
        raw_filters = json.loads(raw_filters) if raw_filters.strip() else None
    filters = raw_filters or {}

    # An explicit null falls back to the default (None would otherwise mean
    # an unbounded Hub listing); int() normalises numeric strings so the
    # slice below cannot fail.
    limit = request_data.get("limit")
    limit = 5 if limit is None else int(limit)

    print("query", query)
    print("filters", filters)
    print("limit", limit)

    api = HfApi()
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, full=True)

    hits = []
    for model in models:
        # Read the model object's attributes as a dict so absent fields
        # come back as None rather than raising AttributeError.
        info = vars(model)
        hits.append(
            {
                "modelId": info.get("modelId"),
                "tags": info.get("tags"),
                "downloads": info.get("downloads"),
                "likes": info.get("likes"),
            }
        )
    count = len(hits)
    # Slicing is a no-op when there are already <= limit hits.
    hits = hits[:limit]
    pprint(hits)
    return json.dumps({"value": hits, "count": count})
@app.route("/semantic/search", methods=["POST"])
def semantic_search():
    """Semantic model search via ``hf_search`` using "retrieve & rerank".

    JSON body fields:
        query: free-text search string.
        filters: JSON-encoded string (an already-decoded object is also
            accepted); the decoded value is forwarded to ``hf_search``.
            A missing or blank value becomes ``{}`` — presumably meaning
            "unfiltered" to hf_search; TODO confirm.
        limit: maximum number of hits (default 5; a numeric string is
            accepted).

    Returns:
        A JSON string ``{"value": hits, "count": len(hits)}``. Each hit holds
        ``modelId``, ``tags``, ``downloads``, ``likes`` and ``readme``
        (None when hf_search does not supply one).
    """
    # A JSON ``null`` body decodes to None; treat it like an empty object.
    request_data = request.get_json() or {}
    query = request_data.get("query")

    # The client sends filters as a JSON-encoded string. Tolerate a missing,
    # blank or already-decoded value instead of failing inside json.loads.
    raw_filters = request_data.get("filters")
    if isinstance(raw_filters, str):
        raw_filters = json.loads(raw_filters) if raw_filters.strip() else None
    filters = raw_filters or {}

    # An explicit null falls back to the default; int() normalises strings.
    limit = request_data.get("limit")
    limit = 5 if limit is None else int(limit)

    print("query", query)
    print("filters", filters)
    print("limit", limit)

    results = hf_search(query=query, method="retrieve & rerank", limit=limit, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in results
    ]
    return json.dumps({"value": hits, "count": len(hits)})
@app.route("/bm25/search", methods=["POST"])
def bm25_search():
    """Keyword (BM25) model search via ``hf_search``, with duplicate models removed.

    JSON body fields:
        query: free-text search string.
        filters: JSON-encoded string (an already-decoded object is also
            accepted). Currently only decoded and logged — it is NOT passed
            to hf_search (see TODO below).
        limit: maximum number of hits requested from hf_search (default 5;
            a numeric string is accepted). Deduplication happens afterwards,
            so fewer than ``limit`` hits may be returned.

    Returns:
        A JSON string ``{"value": hits, "count": len(hits)}``. Each hit holds
        ``modelId``, ``tags``, ``downloads``, ``likes`` and ``readme``
        (None when hf_search does not supply one); every ``modelId`` appears
        at most once.
    """
    # A JSON ``null`` body decodes to None; treat it like an empty object.
    request_data = request.get_json() or {}
    query = request_data.get("query")

    # The client sends filters as a JSON-encoded string. Tolerate a missing,
    # blank or already-decoded value instead of failing inside json.loads.
    raw_filters = request_data.get("filters")
    if isinstance(raw_filters, str):
        raw_filters = json.loads(raw_filters) if raw_filters.strip() else None
    filters = raw_filters or {}

    # An explicit null falls back to the default; int() normalises strings.
    limit = request_data.get("limit")
    limit = 5 if limit is None else int(limit)

    print("query", query)
    print("filters", filters)
    print("limit", limit)

    # TODO: filters are parsed and logged above but not yet passed to hf_search.
    results = hf_search(query=query, method="bm25", limit=limit)

    # Keep only the first hit per modelId, preserving hf_search's order.
    # A set of seen ids keeps this linear instead of rescanning earlier hits.
    seen_ids = set()
    hits = []
    for hit in results:
        model_id = hit["modelId"]
        if model_id in seen_ids:
            continue
        seen_ids.add(model_id)
        hits.append(
            {
                "modelId": model_id,
                "tags": hit["tags"],
                "downloads": hit["downloads"],
                "likes": hit["likes"],
                "readme": hit.get("readme"),
            }
        )
    return json.dumps({"value": hits, "count": len(hits)})
if __name__ == "__main__":
    # Start Flask's built-in server, listening on localhost:5000.
    host, port = "localhost", 5000
    app.run(host=host, port=port)