aaabiao committed
Commit 4e434e6
Parent: 38e817c

Update app.py

Files changed (1):
  1. app.py +26 -6
app.py CHANGED
@@ -5,7 +5,26 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+)
+
+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, stops=None):
+        super().__init__()
+        # Move the stop ids to the GPU here if needed: [stop.to("cuda") for stop in stops]
+        self.stops = stops if stops is not None else []
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
+        last_token = input_ids[0][-1]  # most recently generated token
+        for stop in self.stops:
+            if tokenizer.decode(stop) == tokenizer.decode(last_token):
+                return True
+        return False
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -17,6 +36,6 @@ if torch.cuda.is_available():
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 @spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
@@ -40,6 +59,10 @@ def generate(
     input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    stop_words = ["</s>"]  # halt generation once the end-of-sequence marker appears
+    stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
+    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -48,6 +71,7 @@ def generate(
         top_p=top_p,
         temperature=temperature,
         num_beams=1,
+        stopping_criteria=stopping_criteria,
         repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -56,11 +80,7 @@ def generate(
     outputs = []
     for text in streamer:
         outputs.append(text)
-        generated_text = "".join(outputs)
-        if "<s>" in generated_text:
-            yield generated_text[:generated_text.index("<s>")+3]
-            break
-        yield generated_text
+        yield "".join(outputs)
 
 
 chat_interface = gr.ChatInterface(
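
The heart of this change is the custom stopping criterion, so a minimal, self-contained sketch of the idea follows. It is illustrative only: ToyTokenizer is a hypothetical stand-in for the app's real tokenizer (model_id is defined outside this diff), and the criterion mirrors the committed class, reporting generation as finished once the last generated token decodes to a stop word.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class ToyTokenizer:
    # Hypothetical tokenizer with just enough of the decode() API for this demo.
    vocab = {0: "Hello", 1: " world", 2: "</s>"}

    def decode(self, token_ids, **kwargs):
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        return "".join(self.vocab[t] for t in token_ids)

tokenizer = ToyTokenizer()

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        # Compare the decoded last token against each decoded stop id.
        last_token = input_ids[0][-1]
        return any(tokenizer.decode(stop) == tokenizer.decode(last_token) for stop in self.stops)

stop_ids = [torch.tensor(2)]  # id of "</s>" in the toy vocab
criterion = StoppingCriteriaSub(stops=stop_ids)
stopping_criteria = StoppingCriteriaList([criterion])  # what model.generate() receives

print(criterion(torch.tensor([[0, 1]]), scores=None))     # False: "Hello world" keeps going
print(criterion(torch.tensor([[0, 1, 2]]), scores=None))  # True: ends in "</s>", stop

One caveat: since only the last token is inspected, a stop word that the tokenizer splits into several ids will never match, so stopping on "</s>" relies on the tokenizer encoding it as a single id.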
 
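The streaming loop at the bottom of generate() is the other half of the picture: model.generate() runs in a background Thread and feeds a TextIteratorStreamer, while the foreground loop yields the accumulated text, now cut off cleanly by the stopping criterion instead of the old "<s>" string scan. Below is a sketch of that producer/consumer pattern; the "gpt2" tokenizer and the fake_generate() producer are stand-ins chosen only so the snippet runs without the app's actual checkpoint.

from threading import Thread

import torch
from transformers import AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative; the app uses its own model_id
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=False, skip_special_tokens=True)

def fake_generate():
    # Stands in for model.generate(streamer=streamer, ...): push ids, then signal the end.
    # skip_prompt=False above because this fake producer pushes no prompt tokens.
    for token_id in tokenizer("Streaming tokens one at a time.\n")["input_ids"]:
        streamer.put(torch.tensor([token_id]))
    streamer.end()

Thread(target=fake_generate).start()

outputs = []
for text in streamer:        # blocks until the producer thread supplies more text
    outputs.append(text)
    print("".join(outputs))  # the app yields this growing string to gr.ChatInterface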