hysts (HF staff) committed
Commit 2a6a910 (1 parent: 3b85b9a)

Use transformers

Files changed (7)
  1. .pre-commit-config.yaml +37 -0
  2. .style.yapf +5 -0
  3. README.md +2 -2
  4. app.py +249 -262
  5. requirements.txt +8 -0
  6. style.css +3 -0
  7. utils.py +0 -27
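
For context before the per-file diffs: this commit replaces the old HTTP-endpoint client with direct `transformers` inference. Below is a minimal sketch of that path, condensed from the new app.py in this diff (the model ID, 8-bit loading, and decoding arguments mirror the diff; the image path is a placeholder):

import PIL.Image
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration

# OPT-6.7B checkpoint used for captioning in the new app.py, loaded in 8-bit.
model_id = 'Salesforce/blip2-opt-6.7b'
processor = AutoProcessor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id,
                                                      device_map='auto',
                                                      load_in_8bit=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

image = PIL.Image.open('house.png')  # placeholder: any of the Space's example images

# Captioning path: image-only inputs; max_length=50 as in generate_caption() below.
inputs = processor(images=image, return_tensors='pt').to(device, torch.float16)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())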
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
+ exclude: patch
+ repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.12.0
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.991
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+     additional_dependencies: ['types-python-slugify']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: BLIP2
+ title: BLIP2 with transformers
  emoji: 🌖
  colorFrom: blue
  colorTo: pink
  sdk: gradio
- sdk_version: 3.17.0
+ sdk_version: 3.18.0
  app_file: app.py
  pinned: false
  license: bsd-3-clause
app.py CHANGED
@@ -1,282 +1,269 @@
- from io import BytesIO

- import string
- import gradio as gr
- import requests
- from utils import Endpoint, get_token
-
-
- def encode_image(image):
-     buffered = BytesIO()
-     image.save(buffered, format="JPEG")
-     buffered.seek(0)
-
-     return buffered
-
-
- def query_chat_api(
-     image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
- ):
-
-     url = endpoint.url
-     url = url + "/api/generate"
-
-     headers = {
-         "User-Agent": "BLIP-2 HuggingFace Space",
-         "Auth-Token": get_token(),
-     }
-
-     data = {
-         "prompt": prompt,
-         "use_nucleus_sampling": decoding_method == "Nucleus sampling",
-         "temperature": temperature,
-         "length_penalty": len_penalty,
-         "repetition_penalty": repetition_penalty,
-     }
-
-     image = encode_image(image)
-     files = {"image": image}
-
-     response = requests.post(url, data=data, files=files, headers=headers)

-     if response.status_code == 200:
-         return response.json()
-     else:
-         return "Error: " + response.text
-
-
- def query_caption_api(
-     image, decoding_method, temperature, len_penalty, repetition_penalty
- ):
-
-     url = endpoint.url
-     url = url + "/api/caption"
-
-     headers = {
-         "User-Agent": "BLIP-2 HuggingFace Space",
-         "Auth-Token": get_token(),
-     }

-     data = {
-         "use_nucleus_sampling": decoding_method == "Nucleus sampling",
-         "temperature": temperature,
-         "length_penalty": len_penalty,
-         "repetition_penalty": repetition_penalty,
      }
-
-     image = encode_image(image)
-     files = {"image": image}
-
-     response = requests.post(url, data=data, files=files, headers=headers)
-
-     if response.status_code == 200:
-         return response.json()
-     else:
-         return "Error: " + response.text
-
-
- def postprocess_output(output):
-     # if last character is not a punctuation, add a full stop
-     if not output[0][-1] in string.punctuation:
-         output[0] += "."
-
      return output


- def inference_chat(
-     image,
-     text_input,
-     decoding_method,
-     temperature,
-     length_penalty,
-     repetition_penalty,
-     history=[],
- ):
-     text_input = text_input
-     history.append(text_input)
-
-     prompt = " ".join(history)
-
-     output = query_chat_api(
-         image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
      )
      output = postprocess_output(output)
-     history += output
-
-     chat = [
-         (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
-     ] # convert to tuples of list
-
-     return {chatbot: chat, state: history}
-
-
- def inference_caption(
-     image,
-     decoding_method,
-     temperature,
-     length_penalty,
-     repetition_penalty,
- ):
-     output = query_caption_api(
-         image, decoding_method, temperature, length_penalty, repetition_penalty
-     )
-
-     return output[0]


- title = """<h1 align="center">BLIP-2</h1>"""
- description = """Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them.
- <br> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""
- article = """<strong>Paper</strong>: <a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>
- <br> <strong>Code</strong>: BLIP2 is now integrated into GitHub repo: <a href='https://github.com/salesforce/LAVIS' target='_blank'>LAVIS: a One-stop Library for Language and Vision</a>
- <br> <strong>🤗 `transformers` integration</strong>: You can now use `transformers` to use our BLIP-2 models! Check out the <a href='https://huggingface.co/docs/transformers/main/en/model_doc/blip-2' target='_blank'> official docs </a>
- <p> <strong>Project Page</strong>: <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'> BLIP2 on LAVIS</a>
- <br> <strong>Description</strong>: Captioning results from <strong>BLIP2_OPT_6.7B</strong>. Chat results from <strong>BLIP2_FlanT5xxl</strong>.
- """
-
- endpoint = Endpoint()

  examples = [
-     ["house.png", "How could someone get out of the house?"],
-     ["flower.jpg", "Question: What is this flower and where is it's origin? Answer:"],
-     ["pizza.jpg", "What are steps to cook it?"],
-     ["sunset.jpg", "Here is a romantic message going along the photo:"],
-     ["forbidden_city.webp", "In what dynasties was this place built?"],
  ]

- with gr.Blocks(
-     css="""
-     .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
-     #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
-     """
- ) as iface:
-     state = gr.State([])
-
-     gr.Markdown(title)
-     gr.Markdown(description)
-     gr.Markdown(article)
-
      with gr.Row():
-         with gr.Column(scale=1):
-             image_input = gr.Image(type="pil")
-
-             # with gr.Row():
-             sampling = gr.Radio(
-                 choices=["Beam search", "Nucleus sampling"],
-                 value="Beam search",
-                 label="Text Decoding Method",
-                 interactive=True,
-             )
-
-             temperature = gr.Slider(
-                 minimum=0.5,
-                 maximum=1.0,
-                 value=1.0,
-                 step=0.1,
-                 interactive=True,
-                 label="Temperature (used with nucleus sampling)",
-             )
-
-             len_penalty = gr.Slider(
-                 minimum=-1.0,
-                 maximum=2.0,
-                 value=1.0,
-                 step=0.2,
-                 interactive=True,
-                 label="Length Penalty (set to larger for longer sequence, used with beam search)",
-             )
-
-             rep_penalty = gr.Slider(
-                 minimum=1.0,
-                 maximum=5.0,
-                 value=1.5,
-                 step=0.5,
-                 interactive=True,
-                 label="Repeat Penalty (larger value prevents repetition)",
-             )
-
-         with gr.Column(scale=1.8):
-
-             with gr.Column():
-                 caption_output = gr.Textbox(lines=1, label="Caption Output")
-                 caption_button = gr.Button(
-                     value="Caption it!", interactive=True, variant="primary"
-                 )
-                 caption_button.click(
-                     inference_caption,
-                     [
-                         image_input,
-                         sampling,
-                         temperature,
-                         len_penalty,
-                         rep_penalty,
-                     ],
-                     [caption_output],
-                 )
-
-             gr.Markdown("""Trying prompting your input for chat; e.g. example prompt for QA, \"Question: {} Answer:\" Use proper punctuation (e.g., question mark).""")
-             with gr.Row():
-                 with gr.Column(
-                     scale=1.5,
-                 ):
-                     chatbot = gr.Chatbot(
-                         label="Chat Output (from FlanT5)",
-                     )
-
-                 # with gr.Row():
-                 with gr.Column(scale=1):
-                     chat_input = gr.Textbox(lines=1, label="Chat Input")
-                     chat_input.submit(
-                         inference_chat,
-                         [
-                             image_input,
-                             chat_input,
-                             sampling,
-                             temperature,
-                             len_penalty,
-                             rep_penalty,
-                             state,
-                         ],
-                         [chatbot, state],
-                     )
-
-                     with gr.Row():
-                         clear_button = gr.Button(value="Clear", interactive=True)
-                         clear_button.click(
-                             lambda: ("", [], []),
-                             [],
-                             [chat_input, chatbot, state],
-                             queue=False,
-                         )
-
-                         submit_button = gr.Button(
-                             value="Submit", interactive=True, variant="primary"
-                         )
-                         submit_button.click(
-                             inference_chat,
-                             [
-                                 image_input,
-                                 chat_input,
-                                 sampling,
-                                 temperature,
-                                 len_penalty,
-                                 rep_penalty,
-                                 state,
-                             ],
-                             [chatbot, state],
-                         )

-     image_input.change(
-         lambda: ("", "", []),
-         [],
-         [chatbot, caption_output, state],
-         queue=False,
-     )

-     examples = gr.Examples(
-         examples=examples,
-         inputs=[image_input, chat_input],
      )

- iface.queue(concurrency_count=1, api_open=False, max_size=10)
- iface.launch(enable_queue=True)
 
+ #!/usr/bin/env python

+ from __future__ import annotations

+ import string

+ import gradio as gr
+ import PIL.Image
+ import torch
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+
+ DESCRIPTION = '# BLIP-2'
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+ MODEL_ID_OPT_6_7B = 'Salesforce/blip2-opt-6.7b'
+ MODEL_ID_FLAN_T5_XXL = 'Salesforce/blip2-flan-t5-xxl'
+ model_dict = {
+     MODEL_ID_OPT_6_7B: {
+         'processor':
+         AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
+         'model':
+         Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
+                                                       device_map='auto',
+                                                       load_in_8bit=True),
+     },
+     MODEL_ID_FLAN_T5_XXL: {
+         'processor':
+         AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
+         'model':
+         Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL,
+                                                       device_map='auto',
+                                                       load_in_8bit=True),
      }
+ }
+
+
+ def generate_caption(model_id: str, image: PIL.Image.Image,
+                      decoding_method: str, temperature: float,
+                      length_penalty: float, repetition_penalty: float) -> str:
+     model_info = model_dict[model_id]
+     processor = model_info['processor']
+     model = model_info['model']
+
+     inputs = processor(images=image,
+                        return_tensors='pt').to(device, torch.float16)
+     generated_ids = model.generate(
+         pixel_values=inputs.pixel_values,
+         do_sample=decoding_method == 'Nucleus sampling',
+         temperature=temperature,
+         length_penalty=length_penalty,
+         repetition_penalty=repetition_penalty,
+         max_length=50)
+     result = processor.batch_decode(generated_ids,
+                                     skip_special_tokens=True)[0].strip()
+     return result
+
+
+ def answer_question(model_id: str, image: PIL.Image.Image, text: str,
+                     decoding_method: str, temperature: float,
+                     length_penalty: float, repetition_penalty: float) -> str:
+     model_info = model_dict[model_id]
+     processor = model_info['processor']
+     model = model_info['model']
+
+     inputs = processor(images=image, text=text,
+                        return_tensors='pt').to(device, torch.float16)
+     generated_ids = model.generate(**inputs,
+                                    do_sample=decoding_method ==
+                                    'Nucleus sampling',
+                                    temperature=temperature,
+                                    length_penalty=length_penalty,
+                                    repetition_penalty=repetition_penalty)
+     result = processor.batch_decode(generated_ids,
+                                     skip_special_tokens=True)[0].strip()
+     return result
+
+
+ def postprocess_output(output: str) -> str:
+     if output and not output[-1] in string.punctuation:
+         output += '.'
      return output


+ def chat(
+     model_id: str,
+     image: PIL.Image.Image,
+     text: str,
+     decoding_method: str,
+     temperature: float,
+     length_penalty: float,
+     repetition_penalty: float,
+     history_orig: list[str] = [],
+     history_qa: list[str] = [],
+ ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]]]:
+     history_orig.append(text)
+     text_qa = f'Question: {text} Answer:'
+     history_qa.append(text_qa)
+     prompt = ' '.join(history_qa)
+
+     output = answer_question(
+         model_id,
+         image,
+         prompt,
+         decoding_method,
+         temperature,
+         length_penalty,
+         repetition_penalty,
      )
      output = postprocess_output(output)
+     history_orig.append(output)
+     history_qa.append(output)

+     chat_val = list(zip(history_orig[0::2], history_orig[1::2]))
+     return gr.update(value=chat_val), gr.update(value=history_orig), gr.update(
+         value=history_qa)


  examples = [
+     [
+         'house.png',
+         'How could someone get out of the house?',
+     ],
+     [
+         'flower.jpg',
+         'What is this flower and where is it\'s origin?',
+     ],
+     [
+         'pizza.jpg',
+         'What are steps to cook it?',
+     ],
+     [
+         'sunset.jpg',
+         'Here is a romantic message going along the photo:',
+     ],
+     [
+         'forbidden_city.webp',
+         'In what dynasties was this place built?',
+     ],
  ]

+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     image = gr.Image(type='pil')
+     with gr.Accordion(label='Advanced settings', open=False):
+         with gr.Row():
+             model_id_caption = gr.Dropdown(
+                 label='Model ID for image captioning',
+                 choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
+                 value=MODEL_ID_OPT_6_7B)
+             model_id_chat = gr.Dropdown(
+                 label='Model ID for VQA',
+                 choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
+                 value=MODEL_ID_FLAN_T5_XXL)
+         sampling_method = gr.Radio(
+             label='Text Decoding Method',
+             choices=['Beam search', 'Nucleus sampling'],
+             value='Beam search',
+         )
+         temperature = gr.Slider(
+             label='Temperature (used with nucleus sampling)',
+             minimum=0.5,
+             maximum=1.0,
+             value=1.0,
+             step=0.1,
+         )
+         length_penalty = gr.Slider(
+             label=
+             'Length Penalty (set to larger for longer sequence, used with beam search)',
+             minimum=-1.0,
+             maximum=2.0,
+             value=1.0,
+             step=0.2,
+         )
+         rep_penalty = gr.Slider(
+             label='Repeat Penalty (larger value prevents repetition)',
+             minimum=1.0,
+             maximum=5.0,
+             value=1.5,
+             step=0.5,
+         )
      with gr.Row():
+         with gr.Column():
+             with gr.Box():
+                 gr.Markdown('Image Captioning')
+                 caption_button = gr.Button(value='Caption it!')
+                 caption_output = gr.Textbox(label='Caption Output')
+         with gr.Column():
+             with gr.Box():
+                 gr.Markdown('VQA Chat')
+                 vqa_input = gr.Text(label='Chat Input', max_lines=1)
+                 with gr.Row():
+                     clear_chat_button = gr.Button(value='Clear')
+                     chat_button = gr.Button(value='Submit')
+                 chatbot = gr.Chatbot(label='Chat Output')
+     history_orig = gr.State(value=[])
+     history_qa = gr.State(value=[])
+
+     gr.Examples(
+         examples=examples,
+         inputs=[
+             image,
+             vqa_input,
+         ],
+     )

+     caption_button.click(
+         fn=generate_caption,
+         inputs=[
+             model_id_caption,
+             image,
+             sampling_method,
+             temperature,
+             length_penalty,
+             rep_penalty,
+         ],
+         outputs=caption_output,
+     )

+     chat_inputs = [
+         model_id_chat,
+         image,
+         vqa_input,
+         sampling_method,
+         temperature,
+         length_penalty,
+         rep_penalty,
+         history_orig,
+     ]
+     chat_outputs = [
+         chatbot,
+         history_orig,
+         history_qa,
+     ]
+     vqa_input.submit(
+         fn=chat,
+         inputs=chat_inputs,
+         outputs=chat_outputs,
+     )
+     chat_button.click(
+         fn=chat,
+         inputs=chat_inputs,
+         outputs=chat_outputs,
+     )
+     clear_chat_button.click(
+         fn=lambda: ('', [], [], []),
+         inputs=None,
+         outputs=[
+             vqa_input,
+             chatbot,
+             history_orig,
+             history_qa,
+         ],
+         queue=False,
+     )
+     image.change(
+         fn=lambda: ('', '', [], []),
+         inputs=None,
+         outputs=[
+             chatbot,
+             caption_output,
+             history_orig,
+             history_qa,
+         ],
+         queue=False,
      )

+ demo.queue(max_size=10).launch()
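
For illustration of the new chat flow: chat() keeps the raw turns in history_orig (for display) and the QA-formatted turns in history_qa, and the prompt sent to the model is the space-joined QA history. The flower/rose turns below are hypothetical, not from the commit:

# Hypothetical two-turn exchange, following chat() in the new app.py above.
history_qa = [
    'Question: What is this flower? Answer:',  # user turn, wrapped by chat()
    'It is a rose.',                           # model reply appended by chat()
    'Question: Where does it grow? Answer:',   # next user turn
]
prompt = ' '.join(history_qa)
# prompt == 'Question: What is this flower? Answer: It is a rose. Question: Where does it grow? Answer:'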
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate==0.16.0
+ bitsandbytes==0.37.0
+ git+https://github.com/huggingface/transformers@c836f77
+ gradio==3.18.0
+ huggingface-hub==0.12.0
+ Pillow==9.4.0
+ torch==1.13.1
+ torchvision==0.14.1
style.css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+   text-align: center;
+ }
utils.py DELETED
@@ -1,27 +0,0 @@
- import os
-
-
- class Endpoint:
-     def __init__(self):
-         self._url = None
-
-     @property
-     def url(self):
-         if self._url is None:
-             self._url = self.get_url()
-
-         return self._url
-
-     def get_url(self):
-         endpoint = os.environ.get("endpoint")
-
-         return endpoint
-
-
- def get_token():
-     token = os.environ.get("auth_token")
-
-     if token is None:
-         raise ValueError("auth-token not found in environment variables")
-
-     return token