Files changed (1) hide show
  1. README.md +83 -0
README.md CHANGED
@@ -205,3 +205,86 @@ draw_entity_boxes_on_image(image, entities, show=True)
205
  Here is the annotated image:
206
 
207
  <a href="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" target="_blank"><img src="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" width="500"></a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  Here is the annotated image:
206
 
207
  <a href="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" target="_blank"><img src="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" width="500"></a>
208
+
209
+
210
+ ## Running the Flask Server
211
+ _flask_kosmos2.py_ shows the implementation of a Flask server for the model.
212
+ It allows the model to be accessed as a REST API.
213
+
214
+ After starting the server, you can send a POST request to `http://localhost:8005/process_prompt` with the following form data:
215
+ - `prompt`: For example `<grounding> an image of`
216
+ - `image`: The image file as binary data
217
+
218
+ This in turn will produce a reply with the following JSON format:
219
+ - `message`: The Kosmos-2 generated text
220
+ - `entities`: The extracted entities
221
+
222
+ An easy way to test this is through an application like Postman. Make sure the image field is set to `File`.
223
+
224
+ ```python
225
+
226
+ from PIL import Image
227
+ from transformers import AutoProcessor, AutoModelForVision2Seq
228
+ from flask import Flask, request, jsonify
229
+ import json
230
+
231
# Hub checkpoint shared by the model and its processor.
_CHECKPOINT = "ydshieh/kosmos-2-patch14-224"

# Flask application that exposes the model over HTTP.
app = Flask(__name__)

# Both objects are loaded once at import time and reused by every request.
# trust_remote_code=True lets transformers run the modelling code shipped
# with the checkpoint repository.
model = AutoModelForVision2Seq.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
)
235
+
236
+
237
@app.route('/process_prompt', methods=['POST'])
def process_prompt():
    """Run the model on an uploaded image and prompt and return JSON.

    Expects a ``multipart/form-data`` POST request with:
      - ``image``: the image file.
      - ``prompt``: the text prompt, for example ``<grounding> an image of``.

    Returns:
        On success, a JSON object with ``message`` (the post-processed
        generated text) and ``entities`` (the output of
        ``entities_to_json``).
        On failure, a JSON object ``{"error": <description>}`` together with
        HTTP status 400 when a form field is missing or the upload cannot be
        decoded as an image, and 500 for any other error.
    """
    # Validate the request up front so client mistakes are reported as 400
    # instead of being swallowed by the generic handler below.
    uploaded_file = request.files.get('image')
    if uploaded_file is None:
        return jsonify({"error": "Missing form file field 'image'"}), 400

    prompt = request.form.get('prompt')
    if prompt is None:
        return jsonify({"error": "Missing form field 'prompt'"}), 400

    try:
        image = Image.open(uploaded_file.stream)
        # Image.open is lazy; force decoding here so a corrupt upload is
        # reported as a client error rather than failing later on.
        image.load()
    except OSError as e:
        return jsonify({"error": str(e)}), 400

    try:
        app.logger.debug("image size: %s", image.size)

        inputs = processor(text=prompt, images=image, return_tensors="pt")

        # The last position of the text-side tensors is dropped before
        # generation, as in the original usage example.
        generated_ids = model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"][:, :-1],
            attention_mask=inputs["attention_mask"][:, :-1],
            img_features=None,
            img_attn_mask=inputs["img_attn_mask"][:, :-1],
            use_cache=True,
            max_new_tokens=64,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # By default, the generated text is cleaned up and the entities are extracted.
        processed_text, entities = processor.post_process_generation(generated_text)
        parsed_entities = entities_to_json(entities)
        app.logger.debug("generated text: %s", generated_text)
        app.logger.debug("processed text: %s", processed_text)
        return jsonify({"message": processed_text, 'entities': parsed_entities})
    except Exception as e:
        # Top-level boundary: log the traceback and report a server error
        # instead of answering 200 with an error body.
        app.logger.exception("Failed to process prompt")
        return jsonify({"error": str(e)}), 500
268
+
269
+
270
def entities_to_json(entities):
    """Convert extracted entities into JSON-serializable dictionaries.

    Each entity is read by index as ``(label, pair, boxes)``, where ``pair``
    is a two-element sequence and ``boxes`` is a sequence of four-element
    boxes ``(x_min, y_min, x_max, y_max)``. Only the first box of each
    entity is reported.

    NOTE(review): in the Kosmos-2 processor output the second element looks
    like the (start, end) character span of the label in the generated text
    rather than a box position. The ``boundingBoxPosition`` key and its
    ``x``/``y`` fields are kept unchanged for client compatibility; confirm.

    Args:
        entities: Iterable of entity tuples as described above.

    Returns:
        A list with one dict per entity, holding the keys ``label``,
        ``boundingBoxPosition`` and ``boundingBox``. ``boundingBox`` is
        ``None`` for an entity that has no boxes.
    """
    result = []
    for e in entities:
        label, box_coords, boxes = e[0], e[1], e[2]

        # An entity with no boxes used to raise IndexError and abort the
        # whole response; report it without a box instead.
        if boxes:
            first_box = boxes[0]
            bounding_box = {
                "x_min": first_box[0],
                "y_min": first_box[1],
                "x_max": first_box[2],
                "y_max": first_box[3],
            }
        else:
            bounding_box = None

        result.append({
            "label": label,
            "boundingBoxPosition": {"x": box_coords[0], "y": box_coords[1]},
            "boundingBox": bounding_box,
        })

    return result
285
+
286
+
287
if __name__ == "__main__":
    # Start the Flask development server only when this file is executed
    # directly, bound to the local interface on port 8005.
    app.run(port=8005, host="localhost")
289
+
290
+ ```