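"""Streamlit chat UI for ChatGLM served through ONNX Runtime.

Assumes a sibling `model` module that exposes `ChatGLMModel` (with a
streaming `generate_iterate` method) and a `chat_template` prompt builder,
matching the imports below.
"""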
import streamlit as st
from streamlit_chat import message
from model import ChatGLMModel, chat_template


# page state

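# Cache the model as a singleton so it loads once and is reused across reruns.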
@st.cache_resource
def create_model():
    return ChatGLMModel()

with st.spinner("加载模型中..."):
    model = create_model()


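# Conversation history, stored as (question, answer) pairs across reruns.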
if "history" not in st.session_state:
    st.session_state["history"] = []


# parameters

with st.sidebar:
    st.markdown("## 采样参数")

    max_tokens = st.number_input("max_tokens", min_value=1, max_value=500, value=200)
    temperature = st.number_input("temperature", min_value=0.1, max_value=4.0, value=1.0)
    top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.7)
    top_k = st.number_input("top_k", min_value=1, max_value=500, value=50)

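    # Resetting the "message" key clears the text area; this is allowed
    # because that widget is only instantiated later in the script run.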
    if st.button("清空上下文"):
        st.session_state.message = ""
        st.session_state.history = []

    st.markdown("""
    [ChatGLM](https://huggingface.co/THUDM/chatglm-6b) + [ONNXRuntime](https://onnxruntime.ai/)
    """)


# main body

st.markdown("## ChatGLM + ONNXRuntime")

history: list[tuple[str, str]] = st.session_state.history

if len(history) == 0:
    st.caption("请在下方输入消息开始会话")


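# Replay the conversation so far.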
for idx, (question, answer) in enumerate(history):
    message(question, is_user=True, key=f"history_question_{idx}")
    st.write(answer)
    st.markdown("---")


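# Reserve a container above the input box so the streamed reply
# appears in conversation order rather than below the form.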
next_answer = st.container()

question = st.text_area(label="消息", key="message")

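# On send: echo the question, build the prompt from history, and stream the answer.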
if st.button("发送") and len(question.strip()):
    with next_answer:
        message(question, is_user=True, key="message_question")
        with st.spinner("正在回复中"):
            with st.empty():
                prompt = chat_template(history, question)
                for answer in model.generate_iterate(
                    prompt,
                    max_generated_tokens=max_tokens,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                ):
                    st.write(answer)
        st.markdown("---")

    # Persist the completed exchange; `answer` holds the final streamed text.
    st.session_state.history = history + [(question, answer)]