aaabiao committed on
Commit
1d8d33a
1 Parent(s): 25468b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -32
app.py CHANGED
@@ -5,25 +5,8 @@ from typing import Iterator
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import (
9
- AutoModelForCausalLM,
10
- AutoTokenizer,
11
- StoppingCriteria,
12
- StoppingCriteriaList,
13
- TextIteratorStreamer,
14
- )
15
-
16
- class StoppingCriteriaSub(StoppingCriteria):
17
- def __init__(self, stops = [], encounters=1):
18
- super().__init__()
19
- self.stops = [stop.to("cuda") for stop in stops]
20
-
21
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
22
- last_token = input_ids[0][-1]
23
- for stop in self.stops:
24
- if tokenizer.decode(stop) == tokenizer.decode(last_token):
25
- return True
26
- return False
27
 
28
  MAX_MAX_NEW_TOKENS = 2048
29
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -57,21 +40,22 @@ def generate(
57
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
58
  input_ids = input_ids.to(model.device)
59
 
 
 
60
  stop_words = ["</s>"]
61
  stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
62
- stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
63
 
64
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
65
- generate_kwargs = {
66
- "input_ids": input_ids,
67
- "streamer": streamer,
68
- "max_new_tokens": max_new_tokens,
69
- "do_sample": True,
70
- "top_p": top_p,
71
- "temperature": temperature,
72
- "stopping_criteria": stopping_criteria,
73
- "repetition_penalty": repetition_penalty,
74
- }
75
  t = Thread(target=model.generate, kwargs=generate_kwargs)
76
  t.start()
77
 
@@ -113,7 +97,7 @@ chat_interface = gr.ChatInterface(
113
  value=1.1,
114
  ),
115
  ],
116
- stop_words=stop_words, # Set the stop words
117
  examples=[
118
  ["Hello there! How are you doing?"],
119
  ["Can you explain briefly to me what is the Python programming language?"],
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ from transformers.generation_stopping_criteria import StoppingCriteria, StoppingCriteriaList
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  MAX_MAX_NEW_TOKENS = 2048
12
  DEFAULT_MAX_NEW_TOKENS = 1024
 
40
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
41
  input_ids = input_ids.to(model.device)
42
 
43
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
44
+
45
  stop_words = ["</s>"]
46
  stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
47
+ stopping_criteria = StoppingCriteriaList([StoppingCriteria(stops=stop_words_ids)])
48
 
49
+ generate_kwargs = dict(
50
+ input_ids=model_inputs,
51
+ streamer=streamer,
52
+ max_new_tokens=max_new_tokens,
53
+ do_sample=True,
54
+ top_p=top_p,
55
+ temperature=temperature,
56
+ stopping_criteria=stopping_criteria,
57
+ repetition_penalty=repetition_penalty,
58
+ )
 
59
  t = Thread(target=model.generate, kwargs=generate_kwargs)
60
  t.start()
61
 
 
97
  value=1.1,
98
  ),
99
  ],
100
+ stop_button=True, # Changed stop button to True
101
  examples=[
102
  ["Hello there! How are you doing?"],
103
  ["Can you explain briefly to me what is the Python programming language?"],