Spaces:
Running
Running
option to switch models
Browse files- app.py +21 -4
- utils/haystack.py +14 -6
- utils/ui.py +27 -11
app.py
CHANGED
@@ -15,10 +15,27 @@ set_initial_state()
|
|
15 |
sidebar()
|
16 |
|
17 |
st.write("# Get the summaries of latest top Hacker News posts 🧡")
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
search_bar, button = st.columns(2)
|
23 |
# Search bar
|
24 |
with search_bar:
|
@@ -29,7 +46,7 @@ if st.session_state.get("HF_TGI_TOKEN"):
|
|
29 |
st.write("")
|
30 |
run_pressed = st.button("Get summaries")
|
31 |
else:
|
32 |
-
st.write("Please provide your Hugging Face token to start using the application")
|
33 |
st.write("If you are using a smaller screen, open the sidebar from the top left to provide your token 👈")
|
34 |
|
35 |
if st.session_state.get("api_key_configured"):
|
|
|
15 |
sidebar()
|
16 |
|
17 |
st.write("# Get the summaries of latest top Hacker News posts 🧡")
|
18 |
+
if st.session_state.get("model") == None:
|
19 |
+
mistral, openai = st.columns(2)
|
20 |
|
21 |
+
with mistral:
|
22 |
+
mistral_pressed = st.button("Mistral")
|
23 |
+
if mistral_pressed:
|
24 |
+
st.session_state["model"] = "Mistral"
|
25 |
+
with openai:
|
26 |
+
openai_pressed = st.button("OpenAI")
|
27 |
+
if openai_pressed:
|
28 |
+
st.session_state["model"] = "GPT-4"
|
29 |
+
|
30 |
+
if st.session_state.get("model") and (st.session_state.get("HF_TGI_TOKEN") or st.session_state.get("OPENAI_API_KEY")):
|
31 |
+
if st.session_state.get("HF_TGI_TOKEN"):
|
32 |
+
pipeline = start_haystack(st.session_state.get("HF_TGI_TOKEN"), st.session_state.get("model"))
|
33 |
+
st.session_state["api_key_configured"] = True
|
34 |
+
|
35 |
+
elif st.session_state.get("OPENAI_API_KEY"):
|
36 |
+
pipeline = start_haystack(st.session_state.get("OPENAI_API_KEY"), st.session_state.get("model"))
|
37 |
+
st.session_state["api_key_configured"] = True
|
38 |
+
|
39 |
search_bar, button = st.columns(2)
|
40 |
# Search bar
|
41 |
with search_bar:
|
|
|
46 |
st.write("")
|
47 |
run_pressed = st.button("Get summaries")
|
48 |
else:
|
49 |
+
st.write("Please provide your Hugging Face or OpenAI key to start using the application")
|
50 |
st.write("If you are using a smaller screen, open the sidebar from the top left to provide your token 👈")
|
51 |
|
52 |
if st.session_state.get("api_key_configured"):
|
utils/haystack.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
from haystack import Pipeline
|
3 |
from haystack.components.builders.prompt_builder import PromptBuilder
|
4 |
-
from haystack.components.generators import HuggingFaceTGIGenerator
|
5 |
from .hackernews_fetcher import HackernewsFetcher
|
6 |
|
7 |
-
def start_haystack(
|
8 |
prompt_template = """
|
9 |
You will be provided one or more top HackerNews posts, followed by their URL.
|
10 |
For the posts you have, provide a short summary followed by the URL that the post can be found at.
|
@@ -18,7 +18,10 @@ Summaries:
|
|
18 |
"""
|
19 |
|
20 |
prompt_builder = PromptBuilder(template=prompt_template)
|
21 |
-
|
|
|
|
|
|
|
22 |
fetcher = HackernewsFetcher()
|
23 |
|
24 |
pipe = Pipeline()
|
@@ -34,9 +37,14 @@ Summaries:
|
|
34 |
@st.cache_data(show_spinner=True)
|
35 |
def query(top_k, _pipeline):
|
36 |
try:
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
result = replies['llm']['replies']
|
42 |
except Exception as e:
|
|
|
1 |
import streamlit as st
|
2 |
from haystack import Pipeline
|
3 |
from haystack.components.builders.prompt_builder import PromptBuilder
|
4 |
+
from haystack.components.generators import HuggingFaceTGIGenerator, OpenAIGenerator
|
5 |
from .hackernews_fetcher import HackernewsFetcher
|
6 |
|
7 |
+
def start_haystack(key, model):
|
8 |
prompt_template = """
|
9 |
You will be provided one or more top HackerNews posts, followed by their URL.
|
10 |
For the posts you have, provide a short summary followed by the URL that the post can be found at.
|
|
|
18 |
"""
|
19 |
|
20 |
prompt_builder = PromptBuilder(template=prompt_template)
|
21 |
+
if model == "Mistral":
|
22 |
+
llm = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1", token=key)
|
23 |
+
elif model == "GPT-4":
|
24 |
+
llm = OpenAIGenerator(api_key=key, model="gpt-4")
|
25 |
fetcher = HackernewsFetcher()
|
26 |
|
27 |
pipe = Pipeline()
|
|
|
37 |
@st.cache_data(show_spinner=True)
|
38 |
def query(top_k, _pipeline):
|
39 |
try:
|
40 |
+
run_args = {"hackernews_fetcher": {"top_k": top_k}}
|
41 |
+
|
42 |
+
if st.session_state.get("model") == "Mistral":
|
43 |
+
run_args = {"hackernews_fetcher": {"top_k": top_k},
|
44 |
+
"llm": {"generation_kwargs": {"max_new_tokens": 600}}
|
45 |
+
}
|
46 |
+
|
47 |
+
replies = _pipeline.run(data=run_args)
|
48 |
|
49 |
result = replies['llm']['replies']
|
50 |
except Exception as e:
|
utils/ui.py
CHANGED
@@ -9,14 +9,18 @@ def set_initial_state():
|
|
9 |
set_state_if_absent("top_k", "How many of the top posts would you like a summary for?")
|
10 |
set_state_if_absent("result", None)
|
11 |
set_state_if_absent("haystack_started", False)
|
|
|
12 |
|
13 |
def reset_results(*args):
|
14 |
st.session_state.result = None
|
15 |
st.session_state.top_k = None
|
16 |
|
17 |
-
def
|
18 |
st.session_state["HF_TGI_TOKEN"] = api_key
|
19 |
|
|
|
|
|
|
|
20 |
def sidebar():
|
21 |
with st.sidebar:
|
22 |
# image = Image.open('logo/haystack-logo-colored.png')
|
@@ -33,16 +37,28 @@ def sidebar():
|
|
33 |
"3. Enjoy 🤗\n"
|
34 |
)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
st.markdown("---")
|
48 |
st.markdown(
|
|
|
9 |
set_state_if_absent("top_k", "How many of the top posts would you like a summary for?")
|
10 |
set_state_if_absent("result", None)
|
11 |
set_state_if_absent("haystack_started", False)
|
12 |
+
set_state_if_absent("model", None)
|
13 |
|
14 |
def reset_results(*args):
|
15 |
st.session_state.result = None
|
16 |
st.session_state.top_k = None
|
17 |
|
18 |
+
def set_hf_token(api_key: str):
|
19 |
st.session_state["HF_TGI_TOKEN"] = api_key
|
20 |
|
21 |
+
def set_openai_key(api_key: str):
|
22 |
+
st.session_state["OPENAI_API_KEY"] = api_key
|
23 |
+
|
24 |
def sidebar():
|
25 |
with st.sidebar:
|
26 |
# image = Image.open('logo/haystack-logo-colored.png')
|
|
|
37 |
"3. Enjoy 🤗\n"
|
38 |
)
|
39 |
|
40 |
+
if st.session_state.model == "Mistral":
|
41 |
+
api_key_input = st.text_input(
|
42 |
+
"Hugging Face Token",
|
43 |
+
type="password",
|
44 |
+
placeholder="Paste your Hugging Face TGI Token",
|
45 |
+
help="You can get your token from https://huggingface.co/settings/tokens.",
|
46 |
+
value=st.session_state.get("HF_TGI_TOKEN", ""),
|
47 |
+
)
|
48 |
+
if api_key_input:
|
49 |
+
set_hf_token(api_key_input)
|
50 |
+
|
51 |
+
elif st.session_state.model == "GPT-4":
|
52 |
+
api_key_input = st.text_input(
|
53 |
+
"OpenAI API Key",
|
54 |
+
type="password",
|
55 |
+
placeholder="Paste your OpenAI API Key",
|
56 |
+
help="You can get your API key from https://platform.openai.com/account/api-keys.",
|
57 |
+
value=st.session_state.get("OPENAI_API_KEY", ""),
|
58 |
+
)
|
59 |
+
if api_key_input:
|
60 |
+
set_openai_key(api_key_input)
|
61 |
+
|
62 |
|
63 |
st.markdown("---")
|
64 |
st.markdown(
|