Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
fn.py
CHANGED
@@ -156,6 +156,19 @@ def chatinterface_to_messages(message, history):
|
|
156 |
|
157 |
return messages
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
def chat(message, history = [], instruction = None, args = {}):
|
160 |
global tokenizer, model, cfg
|
161 |
|
@@ -168,20 +181,17 @@ def chat(message, history = [], instruction = None, args = {}):
|
|
168 |
|
169 |
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
)
|
175 |
|
176 |
generate_kwargs = dict(
|
177 |
model_inputs,
|
178 |
do_sample=True,
|
|
|
|
|
179 |
)
|
180 |
|
181 |
-
if 'fastapi' not in args or 'stream' in args and args['stream']:
|
182 |
-
generate_kwargs['streamer'] = streamer
|
183 |
-
generate_kwargs['num_beams'] = 1
|
184 |
-
|
185 |
for k in [
|
186 |
'max_new_tokens',
|
187 |
'temperature',
|
@@ -192,33 +202,21 @@ def chat(message, history = [], instruction = None, args = {}):
|
|
192 |
if cfg[k]:
|
193 |
generate_kwargs[k] = cfg[k]
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
211 |
return content
|
212 |
-
|
213 |
-
def apply_template(messages):
|
214 |
-
global tokenizer, cfg
|
215 |
-
|
216 |
-
if cfg['chat_template']:
|
217 |
-
tokenizer.chat_template = cfg['chat_template']
|
218 |
-
|
219 |
-
if type(messages) is str:
|
220 |
-
if cfg['inst_template']:
|
221 |
-
return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
|
222 |
-
return cfg['instruction'].format(input=messages)
|
223 |
-
if type(messages) is list:
|
224 |
-
return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
|
|
|
156 |
|
157 |
return messages
|
158 |
|
159 |
+
def apply_template(messages):
    """Render ``messages`` into a prompt string.

    A string is formatted as a single instruction-style input: with
    ``cfg['inst_template']`` (given ``instruction`` and ``input``) when that
    entry is truthy, otherwise with ``cfg['instruction']`` (given ``input``).
    A list is handed to ``tokenizer.apply_chat_template`` as a conversation,
    with the generation prompt appended and tokenization disabled.

    Side effect: when ``cfg['chat_template']`` is truthy it is assigned to
    ``tokenizer.chat_template`` before anything is rendered.

    Returns the rendered prompt, or ``None`` when ``messages`` is neither a
    string nor a list.
    """
    global tokenizer, cfg

    if cfg['chat_template']:
        tokenizer.chat_template = cfg['chat_template']

    # isinstance rather than an exact type() match, so str/list subclasses
    # are rendered instead of silently falling through to None.
    if isinstance(messages, str):
        if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
        return cfg['instruction'].format(input=messages)
    if isinstance(messages, list):
        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
    return None
|
171 |
+
|
172 |
def chat(message, history = [], instruction = None, args = {}):
|
173 |
global tokenizer, model, cfg
|
174 |
|
|
|
181 |
|
182 |
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
|
183 |
|
184 |
+
streamer = TextIteratorStreamer(
|
185 |
+
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
|
186 |
+
)
|
|
|
187 |
|
188 |
generate_kwargs = dict(
|
189 |
model_inputs,
|
190 |
do_sample=True,
|
191 |
+
streamer=streamer,
|
192 |
+
num_beams=1,
|
193 |
)
|
194 |
|
|
|
|
|
|
|
|
|
195 |
for k in [
|
196 |
'max_new_tokens',
|
197 |
'temperature',
|
|
|
202 |
if cfg[k]:
|
203 |
generate_kwargs[k] = cfg[k]
|
204 |
|
205 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
206 |
+
t.start()
|
207 |
+
|
208 |
+
model_output = ""
|
209 |
+
for new_text in streamer:
|
210 |
+
model_output += new_text
|
211 |
+
if 'fastapi' in args:
|
212 |
+
# FastAPI callers want only the newly generated delta
|
213 |
+
yield new_text
|
214 |
+
else:
|
215 |
+
# Gradio wants the full text generated so far on every yield
|
216 |
+
yield model_output
|
217 |
+
|
218 |
+
def infer(message, history=None, instruction=None, args=None):
    """Run ``chat`` to completion and return the whole generated text.

    ``chat`` is a generator with two output modes, selected by ``args``:
    when ``'fastapi'`` is a key of ``args`` each item is only the newly
    generated piece; otherwise each item is the full text generated so far.

    Args:
        message: forwarded to ``chat`` unchanged.
        history: forwarded to ``chat``; a fresh empty list when omitted.
        instruction: forwarded to ``chat`` unchanged.
        args: forwarded to ``chat``; a fresh empty dict when omitted.

    Returns:
        The complete generated text as one string ('' if nothing is yielded).
    """
    # None sentinels instead of shared mutable defaults.
    history = [] if history is None else history
    args = {} if args is None else args

    chunks = chat(message, history, instruction, args)

    if 'fastapi' in args:
        # Delta mode: the pieces must be concatenated.
        return ''.join(chunks)

    # Cumulative mode: every item already contains all earlier text, so
    # concatenating would duplicate it; the last item is the full output.
    content = ''
    for content in chunks:
        pass
    return content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -40,5 +40,5 @@ async def api_infer(args: dict):
|
|
40 |
media_type="text/event-stream",
|
41 |
)
|
42 |
else:
|
43 |
-
content = fn.
|
44 |
return {'content': content}
|
|
|
40 |
media_type="text/event-stream",
|
41 |
)
|
42 |
else:
|
43 |
+
content = fn.infer(args['input'], [], args['instruct'], args)
|
44 |
return {'content': content}
|