vilarin committed on
Commit
be961e6
1 Parent(s): c257f19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -3,9 +3,9 @@ import torch
3
  from PIL import Image
4
  import gradio as gr
5
  import spaces
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
- from huggingface_hub.inference._generated.types import TextGenerationStreamOutput, TextGenerationStreamOutputToken
8
  import os
 
9
  from huggingface_hub import hf_hub_download
10
 
11
 
@@ -109,35 +109,28 @@ def stream_chat(message, history: list, system: str, temperature: float, max_new
109
  return_tensors="pt"
110
  ).to(model.device)
111
  images = None
112
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
113
 
114
  generate_kwargs = dict(
115
  input_ids=input_ids,
116
- streamer=streamer,
117
  max_new_tokens=max_new_tokens,
118
  temperature=temperature,
119
  do_sample=True,
 
120
  eos_token_id=terminators,
121
  images=images
122
  )
123
  if temperature == 0:
124
  generate_kwargs["do_sample"] = False
125
 
126
- t = Thread(target=model.generate, kwargs=generate_kwargs)
127
- t.start()
128
  input_token_len = input_ids.shape[1]
129
-
130
- output = ""
131
- for next_text in streamer:
132
- yield TextGenerationStreamOutput(
133
- index=0,
134
- token=TextGenerationStreamOutputToken(
135
- id=0,
136
- logprob=0,
137
- text=next_text,
138
- special=False,
139
- )
140
- )
141
 
142
 
143
  chatbot = gr.Chatbot(height=450)
 
3
  from PIL import Image
4
  import gradio as gr
5
  import spaces
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
7
  import os
8
+ import time
9
  from huggingface_hub import hf_hub_download
10
 
11
 
 
109
  return_tensors="pt"
110
  ).to(model.device)
111
  images = None
 
112
 
113
  generate_kwargs = dict(
114
  input_ids=input_ids,
 
115
  max_new_tokens=max_new_tokens,
116
  temperature=temperature,
117
  do_sample=True,
118
+ num_beams=1,
119
  eos_token_id=terminators,
120
  images=images
121
  )
122
  if temperature == 0:
123
  generate_kwargs["do_sample"] = False
124
 
125
+ output_ids=model.generate(**generate_kwargs)
 
126
  input_token_len = input_ids.shape[1]
127
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
128
+ outputs = outputs.strip()
129
+
130
+ for i in range(len(outputs)):
131
+ time.sleep(0.05)
132
+ yield outputs[: i + 1]
133
+
 
 
 
 
 
134
 
135
 
136
  chatbot = gr.Chatbot(height=450)