merve (HF staff) committed
Commit 140a766
1 parent: 18bcb78

Create app.py

Files changed (1):
  1. app.py (+272, −0)
app.py ADDED
import gradio as gr
from transformers import AutoProcessor, Idefics2ForConditionalGeneration, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import spaces
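# `spaces` provides the @spaces.GPU decorator used below, which requests a GPU
# for the decorated function when running on Hugging Face ZeroGPU Spaces.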

PROCESSOR = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
).to("cuda")
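# Note: _attn_implementation="flash_attention_2" assumes the flash-attn package
# is installed and a sufficiently recent CUDA GPU is available; otherwise,
# dropping the argument should fall back to the default attention implementation.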


def turn_is_pure_media(turn):
    return turn[1] is None


def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
):
    """
    Produces the resulting list that needs to go inside the processor.
    It handles the potential image(s), the history, and the system conditioning.
    """
    resulting_messages = []
    resulting_images = []

    # Format history
    for turn in chat_history:
        if not resulting_messages or (resulting_messages and resulting_messages[-1]["role"] != "user"):
            resulting_messages.append(
                {
                    "role": "user",
                    "content": [],
                }
            )

        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": assistant_utterance.strip()}
                    ],
                }
            )

    # Format current input
    if not user_prompt["files"]:
        resulting_messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt["text"]}
                ],
            }
        )
    else:
        # Choosing to put the image(s) first (i.e. before the text), but this is an arbitrary choice.
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "image"}] * len(user_prompt["files"]) + [
                    {"type": "text", "text": user_prompt["text"]}
                ],
            }
        )
        for im in user_prompt["files"]:
            if isinstance(im, str):
                resulting_images.append(Image.open(im))
            elif isinstance(im, dict):
                resulting_images.append(Image.open(im["path"]))

    return resulting_messages, resulting_images


def extract_images_from_msg_list(msg_list):
    all_images = []
    for msg in msg_list:
        for c_ in msg["content"]:
            if isinstance(c_, Image.Image):
                all_images.append(c_)
    return all_images


@spaces.GPU(duration=180)
def model_inference(
    user_prompt,
    chat_history,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    if user_prompt["text"].strip() == "" and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")

    if user_prompt["text"].strip() == "" and user_prompt["files"]:
        raise gr.Error("Please input a text query along with the image(s).")

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

    # Parameters common to all decoding strategies
    # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Creating model inputs
    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
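    # `prompt` is now the flattened chat string built by the processor's chat
    # template, roughly "User:<image>...<end_of_utterance>\nAssistant:" (the exact
    # format comes from the template), which is why the streamed output below is
    # trimmed at "<end_of_utterance>".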
    inputs = PROCESSOR(text=prompt, images=resulting_images if resulting_images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    generation_args.update(inputs)

    # Run generation in a background thread so tokens can be streamed as they arrive
    thread = Thread(
        target=model.generate,
        kwargs=generation_args,
    )
    thread.start()

    print("Start generating")
    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)
        acc_text += text_token
        if acc_text.endswith("<end_of_utterance>"):
            acc_text = acc_text[: -len("<end_of_utterance>")]
        yield acc_text
    print("Success - generated the following text:", acc_text)
    print("-----")


BOT_AVATAR = "IDEFICS_logo.png"

# Hyper-parameters for generation
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.2,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Sampling produces more diverse outputs than greedy decoding.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    label="Top P",
    info="Higher values sample more low-probability tokens.",
)


chatbot = gr.Chatbot(
    label="Idefics2",
    avatar_images=[None, BOT_AVATAR],
    # height=750,
)


with gr.Blocks(fill_height=True, css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img { width: auto; max-width: 30%; height: auto; max-height: 30%; }") as demo:
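    # Show the sampling-related sliders only when a sampling-based decoding
    # strategy is selected.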
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
            )
        ),
        inputs=decoding_strategy,
        outputs=repetition_penalty,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )
    examples = [
        {"text": "How many items are sold?", "files": ["./example_images/docvqa_example.png"]},
        {"text": "What is this UI about?", "files": ["./example_images/s2w_example.png"]},
        {"text": "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "files": ["./example_images/travel_tips.jpg"]},
        {"text": "Can you tell me a very short story based on this image?", "files": ["./example_images/chicken_on_money.png"]},
        {"text": "Where is this pastry from?", "files": ["./example_images/baklava.png"]},
        {"text": "What percentage is the order status at?", "files": ["./example_images/dummy_pdf.png"]},
        {"text": "As an art critic AI assistant, could you describe this painting in detail and give a thorough critique?", "files": ["./example_images/art_critic.jpg"]},
    ]
    description = "Try [IDEFICS2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b), the instruction fine-tuned IDEFICS2, in this demo. 💬 IDEFICS2 is a state-of-the-art vision-language model on various benchmarks. To get started, upload an image and write a text prompt, or try one of the examples. You can also play with the advanced generation parameters. To learn more about IDEFICS2, read [the blog](https://huggingface.co/blog/idefics2). Note that this model is not as chatty as the upcoming chat-focused model, and it will give shorter answers."

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=examples,
        description=description,
        title="Idefics2 Playground 🐶",
        multimodal=True,
        additional_inputs=[decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
    )

demo.launch(debug=True)
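
For quick reference, a minimal sketch of querying the same checkpoint outside the Gradio UI, using only the transformers calls already used in app.py (the image path and prompt are placeholders):

import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16
).to("cuda")

# One user turn: an image followed by a text question
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[Image.open("example.png")], return_tensors="pt").to("cuda")  # placeholder image

generated_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])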