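"""Streamlit web UI for chatting with ChatGLM served through ONNXRuntime."""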
import streamlit as st
from streamlit_chat import message
from model import ChatGLMModel, chat_template
# page state
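# st.cache_resource keeps a single ChatGLMModel instance alive across
# Streamlit reruns, so the model weights are loaded only once per process.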
@st.cache_resource
def create_model():
    return ChatGLMModel()
with st.spinner("Loading model..."):
    model = create_model()
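
# st.session_state persists across reruns within one browser session;
# the chat history is kept here so it survives each script re-execution.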
if "history" not in st.session_state:
    st.session_state["history"] = []
# parameters
with st.sidebar:
    st.markdown("## Sampling parameters")
    max_tokens = st.number_input("max_tokens", min_value=1, max_value=500, value=200)
    temperature = st.number_input("temperature", min_value=0.1, max_value=4.0, value=1.0)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.7)
    top_k = st.number_input("top_k", min_value=1, max_value=500, value=50)
    if st.button("Clear context"):
        # Resetting the text_area value is safe here because the sidebar
        # runs before the text_area widget is instantiated below.
        st.session_state.message = ""
        st.session_state.history = []
    st.markdown("""
    [ChatGLM](https://huggingface.co/THUDM/chatglm-6b) + [ONNXRuntime](https://onnxruntime.ai/)
    """)
# main body
st.markdown("## ChatGLM + ONNXRuntime")
history: list[tuple[str, str]] = st.session_state.history
if len(history) == 0:
    st.caption("Type a message below to start the conversation")
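# Replay finished exchanges on every rerun (Streamlit re-executes the whole script).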
for idx, (question, answer) in enumerate(history):
    message(question, is_user=True, key=f"history_question_{idx}")
    st.write(answer)
    st.markdown("---")
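
# Reserve a container above the input box so the in-progress answer renders there.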
next_answer = st.container()
question = st.text_area(label="Message", key="message")
if st.button("Send") and len(question.strip()):
    with next_answer:
        message(question, is_user=True, key="message_question")
        with st.spinner("Generating reply..."):
            # st.empty() holds a single element, so each st.write() below
            # replaces the previous partial answer, giving a streaming effect.
            with st.empty():
                prompt = chat_template(history, question)
                for answer in model.generate_iterate(
                    prompt,
                    max_generated_tokens=max_tokens,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                ):
                    st.write(answer)
        st.markdown("---")
    # `answer` holds the last, fully generated reply after the loop ends.
    st.session_state.history = history + [(question, answer)]