Update pages/1_Simple_Chat_UI.py
pages/1_Simple_Chat_UI.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 
 def load_model_tokenizer(model_name, hf_api_key):
-    if model_name == "
-        model_name="
+    if model_name == "LLaMa-2B":
+        model_name="llmware/bling-sheared-llama-2.7b-0.1"
         model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_api_key)
         tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer=hf_api_key)
-    elif model_name == "
-        model_name = "
+    elif model_name == "Red-Pajamas-3b":
+        model_name = "llmware/bling-red-pajamas-3b-0.1"
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
     return (model,tokenizer)
@@ -38,7 +38,7 @@ with st.sidebar:
 if "messages" not in st.session_state.keys():
     st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
 
-model_name = st.radio("Select model to chat", options=["
+model_name = st.radio("Select model to chat", options=["LLaMa-2B", "Red-Pajamas-3b"], horizontal=True, key='model_selection')
 model, tokenizer = load_model_tokenizer(model_name, hf_api_key)
 
 for message in st.session_state.messages:
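For context, the sketch below shows a minimal, self-contained version of what this section of the page does after the commit. Three details are assumptions rather than part of the diff: AutoTokenizer.from_pretrained takes its auth token via token= (the commit passes tokenizer=hf_api_key, which is not a kwarg the tokenizer loader accepts); the sidebar text input supplying hf_api_key is inferred from the "with st.sidebar:" hunk context, since that widget is not shown in the diff; and the st.cache_resource wrapper is added so a Streamlit rerun does not reload the model on every interaction.

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumed: the HF token comes from a sidebar input (the diff's second hunk
# sits under "with st.sidebar:"); the exact widget is not shown in the diff.
with st.sidebar:
    hf_api_key = st.text_input("Hugging Face API key", type="password")

@st.cache_resource  # assumption: cache loaded models across Streamlit reruns
def load_model_tokenizer(model_name, hf_api_key):
    # Map the UI label to the Hugging Face repo id, as in the commit.
    if model_name == "LLaMa-2B":
        model_name = "llmware/bling-sheared-llama-2.7b-0.1"
    elif model_name == "Red-Pajamas-3b":
        model_name = "llmware/bling-red-pajamas-3b-0.1"
    # token= is the auth kwarg for both loaders; `or None` skips auth
    # cleanly when the sidebar field is left empty.
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_api_key or None)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key or None)
    return model, tokenizer

# Seed the chat history on first load, as in the diff's second hunk.
if "messages" not in st.session_state.keys():
    st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

model_name = st.radio("Select model to chat",
                      options=["LLaMa-2B", "Red-Pajamas-3b"],
                      horizontal=True, key="model_selection")
model, tokenizer = load_model_tokenizer(model_name, hf_api_key)

# Replay the stored conversation (the diff's trailing context line).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

Note that the committed version only passes the token when loading the LLaMa-2B branch and repeats the from_pretrained calls in each branch; the sketch hoists them out so both models are loaded the same way, which is a simplification, not a change the commit makes.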