project-baize committed on
Commit
77ff0b3
1 Parent(s): 4bcccc4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +219 -0
  2. requirements.txt +14 -0
app.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ import os
3
+ import logging
4
+ import sys
5
+ import gradio as gr
6
+ import torch
7
+ from app_modules.utils import *
8
+ from app_modules.presets import *
9
+ from app_modules.overwrites import *
10
+
11
+ logging.basicConfig(
12
+ level=logging.DEBUG,
13
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
14
+ )
15
+
16
+ base_model = "decapoda-research/llama-7b-hf"
17
+ adapter_model = "project-baize/baize-lora-7B"
18
+ #tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
19
+
20
+
21
def predict(text,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,):
    """Stream a model reply for ``text``, yielding updated chat state.

    Args:
        text: the new user message.
        chatbot: current display pairs (markdown-rendered) shown in the UI.
        history: raw ``[user, assistant]`` turn pairs.
        top_p, temperature, max_length_tokens, max_context_length_tokens:
            sampling / length controls forwarded to ``greedy_search`` and
            ``generate_prompt_with_history``.

    Yields:
        ``(chatbot_pairs, history_pairs, status_message)`` tuples that Gradio
        streams to the UI.  The original code *returned* these tuples on the
        error paths — a generator consumer never sees a return value, so the
        status never reached the UI; they are yielded here instead.
    """
    if text == "":
        # Nothing to generate; echo current state back with a status message.
        yield chatbot, history, "Empty Context"
        return

    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is False:
        # The prompt no longer fits into the context window.
        yield (
            [[x[0], convert_to_markdown(x[1])] for x in history]
            + [[text, "Sorry, the input is too long."]],
            history,
            "Generate Fail",
        )
        return
    prompt, inputs = inputs

    input_ids = inputs["input_ids"].to(device)

    # Fallback state so `a`/`b`/`x` are always bound, even if the search loop
    # yields nothing displayable before interruption or exhaustion.
    a, b = chatbot, history
    x = ""
    with torch.no_grad():
        for x in greedy_search(input_ids, model, tokenizer,
                               stop_words=["[|Human|]", "[|AI|]"],
                               max_length=max_length_tokens,
                               temperature=temperature,
                               top_p=top_p):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                # Trim anything past a role marker the model started to emit.
                if "[|Human|]" in x:
                    x = x[:x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[:x.index("[|AI|]")].strip()
                x = x.strip(" ")
                a = [[y[0], convert_to_markdown(y[1])] for y in history] \
                    + [[text, convert_to_markdown(x)]]
                b = history + [[text, x]]
                yield a, b, "Generating……"
            if shared_state.interrupted:
                shared_state.recover()
                try:
                    yield a, b, "Stop Success"
                except GeneratorExit:
                    # Consumer closed the stream mid-yield; nothing left to do.
                    pass
                # Always stop once an interrupt was requested (the original
                # kept generating if the status yield failed).
                return
    logging.debug(prompt)
    logging.debug(x)
    logging.debug("=" * 80)
    try:
        yield a, b, "Generate Success"
    except GeneratorExit:
        # Stream already closed by the consumer; swallow and finish.
        pass
64
+
65
def retry(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    """Drop the last turn and regenerate a reply for its user message.

    Args:
        text: unused; kept for signature parity with `predict` (the prompt is
            recovered from `history` instead).
        chatbot, history: current display pairs and raw turn pairs; the last
            entry of each is popped in place before regeneration.
        top_p, temperature, max_length_tokens, max_context_length_tokens:
            forwarded verbatim to `predict`.

    Yields:
        Every ``(chatbot, history, status)`` update produced by `predict`.
    """
    logging.info("Retry……")
    if len(history) == 0:
        # No turn to regenerate.
        yield chatbot, history, "Empty context"
        return
    # Discard the last displayed pair and recover its user prompt.
    chatbot.pop()
    inputs = history.pop()[0]
    # Re-run generation, streaming every intermediate state straight through.
    yield from predict(inputs, chatbot, history, top_p, temperature,
                       max_length_tokens, max_context_length_tokens)
82
+
83
+
84
+ gr.Chatbot.postprocess = postprocess
85
+
86
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
87
+ customCSS = f.read()
88
+
89
+ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
90
+ history = gr.State([])
91
+ user_question = gr.State("")
92
+ with gr.Row():
93
+ gr.HTML(title)
94
+ status_display = gr.Markdown("Success", elem_id="status_display")
95
+ gr.Markdown(description_top)
96
+ with gr.Row(scale=1).style(equal_height=True):
97
+ with gr.Column(scale=5):
98
+ with gr.Row(scale=1):
99
+ chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="80%")
100
+ with gr.Row(scale=1):
101
+ with gr.Column(scale=12):
102
+ user_input = gr.Textbox(
103
+ show_label=False, placeholder="Enter text"
104
+ ).style(container=False)
105
+ with gr.Column(min_width=70, scale=1):
106
+ submitBtn = gr.Button("Send")
107
+ with gr.Column(min_width=70, scale=1):
108
+ cancelBtn = gr.Button("Stop")
109
+
110
+ with gr.Row(scale=1):
111
+ emptyBtn = gr.Button(
112
+ "🧹 New Conversation",
113
+ )
114
+ retryBtn = gr.Button("🔄 Regenerate")
115
+ delLastBtn = gr.Button("🗑️ Remove Last Turn")
116
+ with gr.Column():
117
+ with gr.Column(min_width=50, scale=1):
118
+ with gr.Tab(label="Parameter Setting"):
119
+ gr.Markdown("# Parameters")
120
+ top_p = gr.Slider(
121
+ minimum=-0,
122
+ maximum=1.0,
123
+ value=0.95,
124
+ step=0.05,
125
+ interactive=True,
126
+ label="Top-p",
127
+ )
128
+ temperature = gr.Slider(
129
+ minimum=-0,
130
+ maximum=2.0,
131
+ value=1,
132
+ step=0.1,
133
+ interactive=True,
134
+ label="Temperature",
135
+ )
136
+ max_length_tokens = gr.Slider(
137
+ minimum=0,
138
+ maximum=512,
139
+ value=512,
140
+ step=8,
141
+ interactive=True,
142
+ label="Max Generation Tokens",
143
+ )
144
+ max_context_length_tokens = gr.Slider(
145
+ minimum=0,
146
+ maximum=4096,
147
+ value=2048,
148
+ step=128,
149
+ interactive=True,
150
+ label="Max History Tokens",
151
+ )
152
+ gr.Markdown(description)
153
+
154
+ predict_args = dict(
155
+ fn=predict,
156
+ inputs=[
157
+ user_question,
158
+ chatbot,
159
+ history,
160
+ top_p,
161
+ temperature,
162
+ max_length_tokens,
163
+ max_context_length_tokens,
164
+ ],
165
+ outputs=[chatbot, history, status_display],
166
+ show_progress=True,
167
+ )
168
+ retry_args = dict(
169
+ fn=retry,
170
+ inputs=[
171
+ user_input,
172
+ chatbot,
173
+ history,
174
+ top_p,
175
+ temperature,
176
+ max_length_tokens,
177
+ max_context_length_tokens,
178
+ ],
179
+ outputs=[chatbot, history, status_display],
180
+ show_progress=True,
181
+ )
182
+
183
+ reset_args = dict(
184
+ fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
185
+ )
186
+
187
+ # Chatbot
188
+ cancelBtn.click(cancel_outputing, [], [ status_display])
189
+ transfer_input_args = dict(
190
+ fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True
191
+ )
192
+
193
+ user_input.submit(**transfer_input_args).then(**predict_args)
194
+
195
+ submitBtn.click(**transfer_input_args).then(**predict_args)
196
+
197
+ emptyBtn.click(
198
+ reset_state,
199
+ outputs=[chatbot, history, status_display],
200
+ show_progress=True,
201
+ )
202
+ emptyBtn.click(**reset_args)
203
+
204
+ retryBtn.click(**retry_args)
205
+
206
+ delLastBtn.click(
207
+ delete_last_conversation,
208
+ [chatbot, history],
209
+ [chatbot, history, status_display],
210
+ show_progress=True,
211
+ )
212
+
213
+ demo.title = "Baize"
214
+
215
+ if __name__ == "__main__":
216
+ reload_javascript()
217
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
218
+ share=True, favicon_path="./assets/favicon.ico", inbrowser=True
219
+ )
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ mdtex2html
3
+ pypinyin
4
+ tiktoken
5
+ socksio
6
+ tqdm
7
+ colorama
8
+ duckduckgo_search
9
+ Pygments
10
+ llama_index
11
+ langchain
12
+ markdown
13
+ markdown2
14
+