0x7o committed
Commit 1d72a65
Parent: dcbab3a

Update app.py

Files changed (1): app.py (+40, -23)
app.py CHANGED
@@ -1,5 +1,7 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from threading import Thread
+from typing import Iterator
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
 import spaces
 
@@ -14,31 +14,46 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 @spaces.GPU
-def predict(message, history, max_tokens, temperature, top_p):
-    # Build the chat from the history and the new message
-    chat = [{"role": "user" if i % 2 == 0 else "assistant", "content": m}
-            for i, (m, _) in enumerate(history)] + [{"role": "user", "content": message}]
-
-    # Apply the chat template
-    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-
-    # Encode the input data
-    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
-
-    # Generate the response
-    outputs = model.generate(
-        input_ids=inputs,
-        max_new_tokens=max_tokens,
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9
+) -> Iterator[str]:
+    # Rebuild the conversation from the history and append the new message
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    # Run generation in a background thread and stream tokens as they arrive
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
-        temperature=temperature,
         top_p=top_p,
+        temperature=temperature,
+        num_beams=1
     )
-
-    # Decode the result
-    response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
-
-    return response.strip().replace("assistant", "", 1)
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Yield the accumulated text so the chat UI updates incrementally
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
+
 # Set up the Gradio interface
 iface = gr.ChatInterface(
-    predict,
+    generate,
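
For context, the core change here is replacing blocking generation with token streaming: model.generate() runs on a worker thread while a TextIteratorStreamer hands decoded text back to gr.ChatInterface as it is produced. Below is a minimal, self-contained sketch of that same pattern, assuming a small stand-in checkpoint ("gpt2") and an illustrative MAX_INPUT_TOKEN_LENGTH value, since the Space's actual model and constant are defined elsewhere in app.py.

# Minimal sketch of the streaming pattern the new generate() uses.
# Assumptions (not from the commit): "gpt2" stands in for the Space's
# checkpoint; MAX_INPUT_TOKEN_LENGTH is an illustrative value.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value for illustration

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt: str):
    # Tokenize, keeping only the most recent MAX_INPUT_TOKEN_LENGTH tokens
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs on a background thread; the streamer
    # then yields decoded text fragments on this thread as they arrive
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
    )).start()
    partial = []
    for fragment in streamer:
        partial.append(fragment)
        yield "".join(partial)  # each yield replaces the bot message shown so far

for text in stream_reply("The capital of France is"):
    print(text)

Yielding the cumulative string rather than each fragment matches what gr.ChatInterface expects from a generator: every yield re-renders the growing reply in place.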