Upload README.md
#5
by
JGKaaij
- opened
README.md
CHANGED
@@ -205,3 +205,86 @@ draw_entity_boxes_on_image(image, entities, show=True)
|
|
205 |
Here is the annotated image:
|
206 |
|
207 |
<a href="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" target="_blank"><img src="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" width="500"></a>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
Here is the annotated image:
|
206 |
|
207 |
<a href="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" target="_blank"><img src="https://huggingface.co/ydshieh/kosmos-2-patch14-224/resolve/main/annotated_snowman.jpg" width="500"></a>
|
208 |
+
|
209 |
+
|
210 |
+
## Running the Flask Server
|
211 |
+
_flask_kosmos2.py_ shows the implementation of a Flask server for the model.
|
212 |
+
It allowes the model to be approached as a REST API.
|
213 |
+
|
214 |
+
After starting the server. You can send a POST request to `http://localhost:8005/process_prompt` with the following form data:
|
215 |
+
- `prompt`: For example `<grounding> an image of`
|
216 |
+
- `image`: The image file as binary data
|
217 |
+
|
218 |
+
This in turn will produce a reply with the following JSON format:
|
219 |
+
- `message`: The Kosmos-2 generated text
|
220 |
+
- `entities`: The extracted entities
|
221 |
+
|
222 |
+
An easy way to test this is through an application like Postman. Make sure the image field is set to `File`.
|
223 |
+
|
224 |
+
```python
|
225 |
+
|
226 |
+
from PIL import Image
|
227 |
+
from transformers import AutoProcessor, AutoModelForVision2Seq
|
228 |
+
from flask import Flask, request, jsonify
|
229 |
+
import json
|
230 |
+
|
231 |
+
app = Flask(__name__)
|
232 |
+
|
233 |
+
model = AutoModelForVision2Seq.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
|
234 |
+
processor = AutoProcessor.from_pretrained("ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
|
235 |
+
|
236 |
+
|
237 |
+
@app.route('/process_prompt', methods=['POST'])
|
238 |
+
def process_prompt():
|
239 |
+
try:
|
240 |
+
# Get the uploaded image data from the POST request
|
241 |
+
uploaded_file = request.files['image']
|
242 |
+
prompt = request.form.get('prompt')
|
243 |
+
image = Image.open(uploaded_file.stream)
|
244 |
+
|
245 |
+
print(image.size)
|
246 |
+
|
247 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt")
|
248 |
+
|
249 |
+
generated_ids = model.generate(
|
250 |
+
pixel_values=inputs["pixel_values"],
|
251 |
+
input_ids=inputs["input_ids"][:, :-1],
|
252 |
+
attention_mask=inputs["attention_mask"][:, :-1],
|
253 |
+
img_features=None,
|
254 |
+
img_attn_mask=inputs["img_attn_mask"][:, :-1],
|
255 |
+
use_cache=True,
|
256 |
+
max_new_tokens=64,
|
257 |
+
)
|
258 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
259 |
+
|
260 |
+
# By default, the generated text is cleanup and the entities are extracted.
|
261 |
+
processed_text, entities = processor.post_process_generation(generated_text)
|
262 |
+
parsed_entities = entities_to_json(entities)
|
263 |
+
print(generated_text)
|
264 |
+
print(processed_text)
|
265 |
+
return jsonify({"message": processed_text, 'entities': parsed_entities})
|
266 |
+
except Exception as e:
|
267 |
+
return jsonify({"error": str(e)})
|
268 |
+
|
269 |
+
|
270 |
+
def entities_to_json(entities):
|
271 |
+
result = []
|
272 |
+
for e in entities:
|
273 |
+
label = e[0]
|
274 |
+
box_coords = e[1]
|
275 |
+
box_size = e[2][0]
|
276 |
+
entity_result = {
|
277 |
+
"label": label,
|
278 |
+
"boundingBoxPosition": {"x": box_coords[0], "y": box_coords[1]},
|
279 |
+
"boundingBox": {"x_min": box_size[0], "y_min": box_size[1], "x_max": box_size[2], "y_max": box_size[3]}
|
280 |
+
}
|
281 |
+
print(entity_result)
|
282 |
+
result.append(entity_result)
|
283 |
+
|
284 |
+
return result
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == '__main__':
|
288 |
+
app.run(host='localhost', port=8005)
|
289 |
+
|
290 |
+
```
|