Spaces:
Running
Running
Commit
·
cf61e60
1
Parent(s):
da75a62
update routes
Browse files- pages/search_engine.py +23 -32
- server/api.py +37 -2
- streamlit_app.py +11 -0
pages/search_engine.py
CHANGED
@@ -50,25 +50,12 @@ def paginator(label, articles, articles_per_page=10, on_sidebar=True):
|
|
50 |
|
51 |
return itertools.islice(enumerate(articles), min_index, max_index)
|
52 |
|
53 |
-
|
54 |
def page():
|
55 |
-
st.set_page_config(
|
56 |
-
page_title="HF Search Engine",
|
57 |
-
page_icon="🔎",
|
58 |
-
layout="wide",
|
59 |
-
initial_sidebar_state="auto",
|
60 |
-
# menu_items={
|
61 |
-
# "Get Help": "https://www.extremelycoolapp.com/help",
|
62 |
-
# "Report a bug": "https://www.extremelycoolapp.com/bug",
|
63 |
-
# "About": "# This is a header. This is an *extremely* cool app!",
|
64 |
-
# },
|
65 |
-
)
|
66 |
-
|
67 |
### SIDEBAR
|
68 |
search_backend = st.sidebar.selectbox(
|
69 |
-
"Search
|
70 |
-
["
|
71 |
-
format_func=lambda x: {"hfapi": "
|
72 |
)
|
73 |
limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)
|
74 |
|
@@ -112,22 +99,22 @@ def page():
|
|
112 |
if search_query != "":
|
113 |
response = requests.post(search_url, headers=headers, json=search_body).json()
|
114 |
|
115 |
-
|
116 |
_ = [
|
117 |
-
|
118 |
{
|
119 |
-
"modelId":
|
120 |
-
"tags":
|
121 |
-
"downloads":
|
122 |
-
"likes":
|
|
|
123 |
}
|
124 |
)
|
125 |
-
for
|
126 |
]
|
127 |
|
128 |
-
# filter results
|
129 |
|
130 |
-
if
|
131 |
st.write(f'Search results ({response.get("count")}):')
|
132 |
|
133 |
if response.get("count") > 100:
|
@@ -135,16 +122,20 @@ def page():
|
|
135 |
else:
|
136 |
shown_results = response.get("count")
|
137 |
|
138 |
-
for i,
|
139 |
f"Select results (showing {shown_results} of {response.get('count')} results)",
|
140 |
-
|
141 |
):
|
142 |
col1, col2, col3 = st.columns([5,1,1])
|
143 |
-
col1.metric("Model",
|
144 |
-
col2.metric("N° downloads", numerize(
|
145 |
-
col3.metric("N° likes", numerize(
|
146 |
-
st.button(f"View model", on_click=lambda
|
147 |
-
st.markdown(f"**Tags:** {' • '.join(
|
|
|
|
|
|
|
|
|
148 |
|
149 |
# TODO: embed huggingface spaces
|
150 |
# import streamlit.components.v1 as components
|
|
|
50 |
|
51 |
return itertools.islice(enumerate(articles), min_index, max_index)
|
52 |
|
|
|
53 |
def page():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
### SIDEBAR
|
55 |
search_backend = st.sidebar.selectbox(
|
56 |
+
"Search method",
|
57 |
+
["semantic", "bm25", "hfapi"],
|
58 |
+
format_func=lambda x: {"hfapi": "Keyword search", "bm25": "BM25 search", "semantic": "Semantic Search"}[x],
|
59 |
)
|
60 |
limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)
|
61 |
|
|
|
99 |
if search_query != "":
|
100 |
response = requests.post(search_url, headers=headers, json=search_body).json()
|
101 |
|
102 |
+
hit_list = []
|
103 |
_ = [
|
104 |
+
hit_list.append(
|
105 |
{
|
106 |
+
"modelId": hit["modelId"],
|
107 |
+
"tags": hit["tags"],
|
108 |
+
"downloads": hit["downloads"],
|
109 |
+
"likes": hit["likes"],
|
110 |
+
"readme": hit.get("readme", None),
|
111 |
}
|
112 |
)
|
113 |
+
for hit in response.get("value")
|
114 |
]
|
115 |
|
|
|
116 |
|
117 |
+
if hit_list:
|
118 |
st.write(f'Search results ({response.get("count")}):')
|
119 |
|
120 |
if response.get("count") > 100:
|
|
|
122 |
else:
|
123 |
shown_results = response.get("count")
|
124 |
|
125 |
+
for i, hit in paginator(
|
126 |
f"Select results (showing {shown_results} of {response.get('count')} results)",
|
127 |
+
hit_list,
|
128 |
):
|
129 |
col1, col2, col3 = st.columns([5,1,1])
|
130 |
+
col1.metric("Model", hit["modelId"])
|
131 |
+
col2.metric("N° downloads", numerize(hit["downloads"]))
|
132 |
+
col3.metric("N° likes", numerize(hit["likes"]))
|
133 |
+
st.button(f"View model on 🤗", on_click=lambda hit=hit: webbrowser.open(f"https://huggingface.co/{hit['modelId']}"), key=hit["modelId"])
|
134 |
+
st.markdown(f"**Tags:** {' • '.join(hit['tags'])}")
|
135 |
+
|
136 |
+
if hit["readme"]:
|
137 |
+
with st.expander("See README"):
|
138 |
+
st.write(hit["readme"])
|
139 |
|
140 |
# TODO: embed huggingface spaces
|
141 |
# import streamlit.components.v1 as components
|
server/api.py
CHANGED
@@ -46,8 +46,8 @@ def hf_api():
|
|
46 |
return json.dumps({"value": hits, "count": count})
|
47 |
|
48 |
|
49 |
-
@app.route("/
|
50 |
-
def
|
51 |
request_data = request.get_json()
|
52 |
query = request_data.get("query")
|
53 |
filters = json.loads(request_data.get("filters"))
|
@@ -58,6 +58,41 @@ def main():
|
|
58 |
|
59 |
# TODO: filters
|
60 |
hits = hf_search(query=query, method="retrieve & rerank", limit=limit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
return json.dumps({"value": hits, "count": len(hits)})
|
63 |
|
|
|
46 |
return json.dumps({"value": hits, "count": count})
|
47 |
|
48 |
|
49 |
+
@app.route("/semantic/search", methods=["POST"])
|
50 |
+
def semantic_search():
|
51 |
request_data = request.get_json()
|
52 |
query = request_data.get("query")
|
53 |
filters = json.loads(request_data.get("filters"))
|
|
|
58 |
|
59 |
# TODO: filters
|
60 |
hits = hf_search(query=query, method="retrieve & rerank", limit=limit)
|
61 |
+
hits = [
|
62 |
+
{
|
63 |
+
"modelId": hit["modelId"],
|
64 |
+
"tags": hit["tags"],
|
65 |
+
"downloads": hit["downloads"],
|
66 |
+
"likes": hit["likes"],
|
67 |
+
"readme": hit.get("readme", None),
|
68 |
+
}
|
69 |
+
for hit in hits
|
70 |
+
]
|
71 |
+
return json.dumps({"value": hits, "count": len(hits)})
|
72 |
+
|
73 |
+
@app.route("/bm25/search", methods=["POST"])
|
74 |
+
def bm25_search():
|
75 |
+
request_data = request.get_json()
|
76 |
+
query = request_data.get("query")
|
77 |
+
filters = json.loads(request_data.get("filters"))
|
78 |
+
limit = request_data.get("limit", 5)
|
79 |
+
print("query", query)
|
80 |
+
print("filters", filters)
|
81 |
+
print("limit", limit)
|
82 |
+
|
83 |
+
# TODO: filters
|
84 |
+
hits = hf_search(query=query, method="bm25", limit=limit)
|
85 |
+
hits = [
|
86 |
+
{
|
87 |
+
"modelId": hit["modelId"],
|
88 |
+
"tags": hit["tags"],
|
89 |
+
"downloads": hit["downloads"],
|
90 |
+
"likes": hit["likes"],
|
91 |
+
"readme": hit.get("readme", None),
|
92 |
+
}
|
93 |
+
for hit in hits
|
94 |
+
]
|
95 |
+
pprint(hits)
|
96 |
|
97 |
return json.dumps({"value": hits, "count": len(hits)})
|
98 |
|
streamlit_app.py
CHANGED
@@ -11,6 +11,17 @@ def set_record(record):
|
|
11 |
|
12 |
|
13 |
if not st.session_state["selected_record"]: # search engine page
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
search_engine_page()
|
15 |
|
16 |
else: # a record has been selected
|
|
|
11 |
|
12 |
|
13 |
if not st.session_state["selected_record"]: # search engine page
|
14 |
+
st.set_page_config(
|
15 |
+
page_title="HuggingFace Search Engine",
|
16 |
+
page_icon="🔎",
|
17 |
+
layout="wide",
|
18 |
+
initial_sidebar_state="auto",
|
19 |
+
# menu_items={
|
20 |
+
# "Get Help": "https://www.extremelycoolapp.com/help",
|
21 |
+
# "Report a bug": "https://www.extremelycoolapp.com/bug",
|
22 |
+
# "About": "# This is a header. This is an *extremely* cool app!",
|
23 |
+
# },
|
24 |
+
)
|
25 |
search_engine_page()
|
26 |
|
27 |
else: # a record has been selected
|