pszemraj committed
Commit 27da979 · Parent(s): 0240ed4

Update app.py

Files changed (1): app.py (+70, -16)
app.py CHANGED
@@ -1,4 +1,11 @@
 from threading import Thread
+import logging
+import time
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
+)
 
 import torch
 import gradio as gr
@@ -6,24 +13,39 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStrea
 
 model_id = "pszemraj/flan-t5-large-instruct-dolly_hhrlhf"
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-print("Running on device:", torch_device)
-print("CPU threads:", torch.get_num_threads())
+logging.info("Running on device: %s", torch_device)
+logging.info("CPU threads: %s", torch.get_num_threads())
 
 
 if torch_device == "cuda":
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        model_id, load_in_8bit=True, device_map="auto"
+    )
 else:
     model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
-def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_generation_config=False):
+def run_generation(
+    user_text,
+    top_p,
+    temperature,
+    top_k,
+    max_new_tokens,
+    no_repeat_ngram_size=4,
+    length_penalty=1.0,
+    repetition_penalty=1.1,
+    use_generation_config=False,
+):
+    st = time.perf_counter()
     # Get the model and tokenizer, and tokenize the user text.
     model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
     # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
@@ -32,7 +54,8 @@ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_gen
         num_beams=1,
         top_p=top_p,
         temperature=float(temperature),
-        top_k=top_k
+        top_k=top_k,
+        no_repeat_ngram_size=no_repeat_ngram_size,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -42,15 +65,18 @@ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_gen
     for new_text in streamer:
         model_output += new_text
         yield model_output
+    logging.info("Total rt:\t{rt} sec".format(rt=round(time.perf_counter() - st, 3)))
     return model_output
 
 
 def reset_textbox():
-    return gr.update(value='')
+    return gr.update(value="")
 
 
 with gr.Blocks() as demo:
-    duplicate_link = "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
+    duplicate_link = (
+        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
+    )
     gr.Markdown(
         "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
         "This demo showcases the use of the "
@@ -65,26 +91,54 @@ with gr.Blocks() as demo:
         with gr.Column(scale=4):
             user_text = gr.Textbox(
                 placeholder="Write an email about an alpaca that likes flan",
-                label="User input"
+                label="User input",
             )
             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
             button_submit = gr.Button(value="Submit")
 
         with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
-                minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
+                minimum=1,
+                maximum=1000,
+                value=250,
+                step=1,
+                interactive=True,
+                label="Max New Tokens",
            )
            top_p = gr.Slider(
-                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
+                minimum=0.05,
+                maximum=1.0,
+                value=0.95,
+                step=0.05,
+                interactive=True,
+                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
-                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
+                minimum=1,
+                maximum=50,
+                value=50,
+                step=1,
+                interactive=True,
+                label="Top-k",
            )
            temperature = gr.Slider(
-                minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
+                minimum=0.1,
+                maximum=5.0,
+                value=0.8,
+                step=0.1,
+                interactive=True,
+                label="Temperature",
            )
 
-    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
-    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
+    user_text.submit(
+        run_generation,
+        [user_text, top_p, temperature, top_k, max_new_tokens],
+        model_output,
+    )
+    button_submit.click(
+        run_generation,
+        [user_text, top_p, temperature, top_k, max_new_tokens],
+        model_output,
+    )
 
-demo.queue(max_size=32).launch(enable_queue=True)
+demo.queue(max_size=32).launch(enable_queue=True)
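
For context, the threaded streaming setup that run_generation builds on can be exercised outside Gradio. A minimal sketch, assuming the same checkpoint and a transformers release that ships TextIteratorStreamer; the prompt and generation settings below are illustrative and not taken from the app:

from threading import Thread

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

model_id = "pszemraj/flan-t5-large-instruct-dolly_hhrlhf"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)

# Tokenize the prompt and assemble the kwargs that generate() will receive.
inputs = tokenizer(["Write an email about an alpaca that likes flan"], return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=64, do_sample=True, top_p=0.95)

# generate() blocks until it finishes, so it runs on a worker thread while the
# main thread consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

output = ""
for chunk in streamer:
    output += chunk
    print(chunk, end="", flush=True)
thread.join()

The same loop is what run_generation yields from, which is how the Gradio textbox updates incrementally instead of waiting for the full completion.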
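
One detail worth noting about the move from print to the logging module: unlike print, the standard-library logger does not join extra positional arguments onto the message; it substitutes them into %-style placeholders, so values such as the device name need a matching %s in the format string. A small sketch of the pattern (the device value here is illustrative):

import logging

logging.basicConfig(level=logging.INFO)

device = "cpu"
# logging.info("msg", value) treats value as a %-format argument,
# so the message string must contain a matching placeholder.
logging.info("Running on device: %s", device)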