|
import streamlit as st |
|
import requests |
|
import json |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
# Load environment variables (e.g. TOGETHER_API_KEY) from a local .env file.
load_dotenv()
|
|
|
def reset_conversation():
    """Clear the chat by emptying both session-state histories.

    Wired to the sidebar "Reset Chat" button; Streamlit reruns the script
    afterwards, so the next render starts from an empty conversation.
    """
    st.session_state.conversation = []
    st.session_state.messages = []
|
|
|
|
|
# Maps each display name shown in the sidebar selectbox to the underlying
# model identifier sent to the inference backend. Both assistants currently
# point at the same Together-hosted model.
model_links = {

    "Addiction recovery AI": "NousResearch/Nous-Hermes-2-Yi-34B",

    "Mental health AI": "NousResearch/Nous-Hermes-2-Yi-34B"

}
|
|
|
# Sidebar metadata for each selectable assistant: a short description and a
# logo URL (rendered via st.sidebar.markdown / st.sidebar.image below).
# Keys must match the keys of model_links.
model_info = {

    "Addiction recovery AI": {

        'description': "This model provides support and guidance for individuals on their addiction recovery journey.",

        'logo': 'https://example.com/addiction_recovery_logo.png'

    },

    "Mental health AI": {

        'description': "This model offers assistance and resources for individuals dealing with mental health concerns.",

        'logo': 'https://example.com/mental_health_logo.png'

    }

}
|
|
|
|
|
def interact_with_huggingface_model(messages, model):
    """Placeholder for querying a Hugging Face-hosted model.

    Args:
        messages: Chat history as a list of (role, content) tuples.
        model: Model identifier, e.g. "NousResearch/Nous-Hermes-2-Yi-34B".

    Returns:
        str: A visible placeholder reply. The previous stub fell through
        with ``pass`` and returned None, which the caller then passed to
        ``st.markdown()``; an explicit message keeps the chat readable.
    """
    # TODO: implement a real Hugging Face Inference API call here.
    return f"Sorry, the {model} integration is not implemented yet."
|
|
|
|
|
def interact_with_together_api(messages):
    """Send the chat history to the Together chat-completions API.

    Args:
        messages: Chat history as a list of (role, content) tuples, as
            stored in ``st.session_state.messages`` (roles are "user" /
            "assistant").

    Returns:
        str: The assistant's reply text.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.

    Note:
        The previous implementation unpacked each (role, content) tuple as
        a (human, assistant) *pair*, which mangled the roles, duplicated the
        last message, and raised IndexError on an empty history despite its
        empty-list guard. Converting tuple-by-tuple fixes all three.
    """
    # Convert (role, content) tuples straight into the API's message schema.
    all_messages = [
        {"role": role, "content": content} for role, content in messages
    ]

    # The API expects at least one message; send an empty user turn if the
    # history is empty so the request is still well-formed.
    if not all_messages:
        all_messages.append({"role": "user", "content": ""})

    url = "https://api.together.xyz/v1/chat/completions"
    payload = {
        "model": "NousResearch/Nous-Hermes-2-Yi-34B",
        "temperature": 1.05,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1,
        "n": 1,
        "messages": all_messages,
    }

    # Read the key at call time so a .env loaded after import still works.
    TOGETHER_API_KEY = os.getenv('TOGETHER_API_KEY')
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
    }

    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()

    response_data = response.json()
    return response_data["choices"][0]["message"]["content"]
|
|
|
|
|
# Sidebar: model picker plus a reset button that clears session state.
selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))

st.sidebar.button('Reset Chat', on_click=reset_conversation)




# Show the selected model's description, logo, and supporting links.
st.sidebar.write(f"You're now chatting with **{selected_model}**")

st.sidebar.markdown(model_info[selected_model]['description'])

st.sidebar.image(model_info[selected_model]['logo'])

st.sidebar.markdown("*Generated content may be inaccurate or false.*")

st.sidebar.markdown("\nLearn how to build this chatbot [here](https://ngebodh.github.io/projects/2024-03-05/).")

st.sidebar.markdown("\nRun into issues? Try the [back-up](https://huggingface.co/spaces/ngebodh/SimpleChatbot-Backup).")
|
|
|
|
|
# Initialise the message history on the first run of the script.
if "messages" not in st.session_state:
    st.session_state.messages = []


# Replay every stored (role, content) turn so the conversation survives
# Streamlit's rerun-on-interaction model.
for role, content in st.session_state.messages:
    with st.chat_message(role):
        st.markdown(content)
|
|
|
|
|
if prompt := st.chat_input(f"Hi, I'm {selected_model}, ask me a question"):

    # Echo the user's message and persist it before generating a reply.
    with st.chat_message("user"):
        st.markdown(prompt)

    st.session_state.messages.append(("user", prompt))

    # Route on the underlying model id, not the display name: the old check
    # compared selected_model (e.g. "Mental health AI") against
    # "Nous-Hermes-2-Yi-34B", which never matched, so the Together API
    # branch was dead code and the unimplemented HF stub always ran.
    if model_links[selected_model] == "NousResearch/Nous-Hermes-2-Yi-34B":
        assistant_response = interact_with_together_api(st.session_state.messages)
    else:
        assistant_response = interact_with_huggingface_model(st.session_state.messages, model_links[selected_model])

    # Show a transient typing indicator, then render the reply inside an
    # assistant chat bubble — consistent with how history is replayed above
    # (the old code drew the reply outside any chat_message container).
    typing_placeholder = st.empty()
    typing_placeholder.markdown("AI is typing...")
    typing_placeholder.empty()
    with st.chat_message("assistant"):
        st.markdown(assistant_response)

    # Persist the assistant turn so it is replayed on the next rerun.
    st.session_state.messages.append(("assistant", assistant_response))
|
|