Akbartus committed on
Commit 3ef72a4 · verified · 1 Parent(s): 7510b5f

Create app.py

Files changed (1): app.py ADDED (+183 lines)

import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch
import spaces
import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
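# The commented install command and the commented `_attn_implementation`
# argument below belong together: enable both only when flash-attn is available.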

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2",
).to("cuda")
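
# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call;
# the model itself is loaded only once at startup.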

@spaces.GPU
def model_inference(
    images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")

    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    # Accept a single PIL image as well as a list of images.
    if isinstance(images, Image.Image):
        images = [images]

    # Prepend the optional assistant prefix before the messages are built,
    # so that it actually ends up in the prompt.
    if assistant_prefix:
        text = f"{assistant_prefix} {text}"

    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"}] + [
                {"type": "text", "text": text}
            ]
        }
    ]

    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
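    # The chat template renders the messages into the prompt string the model
    # expects, including a placeholder for the image.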
    inputs = processor(text=prompt, images=[images], return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
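    # do_sample, temperature and top_p are filled in below according to the
    # selected decoding strategy.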

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    generation_args.update(inputs)

    # Generate, then decode only the new tokens, i.e. everything after the
    # prompt's input_ids.
    generated_ids = model.generate(**generation_args)

    generated_texts = processor.batch_decode(
        generated_ids[:, generation_args["input_ids"].size(1):],
        skip_special_tokens=True,
    )
    return generated_texts[0]
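
# A quick smoke test might look like this (hypothetical call, not part of the
# app; assumes the example image exists locally):
#
#   out = model_inference(
#       Image.open("example_images/rococo.jpg"), "What art era is this?",
#       "", "Greedy", 0.4, 512, 1.2, 0.8,
#   )
#   print(out)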

with gr.Blocks(fill_height=False) as demo:
    gr.Markdown("## SmolVLM: Small yet Mighty 💫")
    gr.Markdown("Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples.")
    with gr.Column():
        with gr.Row():
            image_input = gr.Image(label="Upload your Image", type="pil")

            with gr.Column():
                query_input = gr.Textbox(label="Prompt")
                assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")

        submit_btn = gr.Button("Submit")
        output = gr.Textbox(label="Output")
91
+ with gr.Accordion(label="Advanced Generation Parameters", open=False):
92
+ examples=[
93
+ ["example_images/rococo.jpg", "What art era is this?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
94
+ ["example_images/examples_wat_arun.jpg", "I'm planning a visit to this temple, give me travel tips.", "", "Greedy", 0.4, 512, 1.2, 0.8],
95
+ ["example_images/examples_invoice.png", "What is the due date and the invoice date?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
96
+ ["example_images/s2w_example.png", "What is this UI about?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
97
+ ["example_images/examples_weather_events.png", "Where do the severe droughts happen according to this diagram?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
98
+ ]
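        # Each example row matches the gr.Examples inputs below: image, prompt,
        # assistant prefix, decoding strategy, temperature, max new tokens,
        # repetition penalty, top_p.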
        # Hyper-parameters for generation
        max_new_tokens = gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        )
        repetition_penalty = gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values will produce more diverse outputs.",
        )
        top_p = gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values sample more low-probability tokens.",
        )
        decoding_strategy = gr.Radio(
            [
                "Top P Sampling",
                "Greedy",
            ],
            value="Top P Sampling",
            label="Decoding strategy",
            interactive=True,
            info="Greedy always picks the most likely next token; Top P samples from the smallest set of tokens whose cumulative probability exceeds Top P.",
        )
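        # The handlers below show the temperature, repetition penalty and
        # Top P sliders only while "Top P Sampling" is selected.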
        decoding_strategy.change(
            fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
            inputs=decoding_strategy,
            outputs=temperature,
        )
        decoding_strategy.change(
            fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
            inputs=decoding_strategy,
            outputs=repetition_penalty,
        )
        decoding_strategy.change(
            fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
            inputs=decoding_strategy,
            outputs=top_p,
        )
        gr.Examples(
            examples=examples,
            inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
                    max_new_tokens, repetition_penalty, top_p],
            outputs=output,
            fn=model_inference,
        )

    submit_btn.click(
        model_inference,
        inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
                max_new_tokens, repetition_penalty, top_p],
        outputs=output,
    )


demo.launch(debug=True)