# demo/api.py
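"""Flask inference API for a Florence-2 vision-language model.

Loads the model once at startup, then serves POST /predict, which takes a
JSON payload with a base64-encoded image plus optional "task", "prompt",
and "msgid" fields and returns the parsed model prediction.
"""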
import argparse
from flask import Flask, request, jsonify
from PIL import Image
from io import BytesIO
import base64
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
import threading
app = Flask(__name__)
# Parse command line arguments
parser = argparse.ArgumentParser(description='Start the Flask server with specified model and device.')
parser.add_argument('--model-path', type=str, default="models/Florence-2-base-ft", help='Path to the pretrained model')
parser.add_argument('--device', type=str, choices=['cpu', 'gpu', 'auto'], default='auto', help='Device to use: "cpu", "gpu", or "auto"')
args = parser.parse_args()
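# Example launch (using the defaults defined above):
#   python demo/api.py --model-path models/Florence-2-base-ft --device auto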
# Determine the device
if args.device == 'auto':
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
elif args.device == 'gpu':
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        raise ValueError("GPU option specified but no GPU is available.")
else:
    device = "cpu"

# Half precision on GPU, full precision on CPU
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32
# Initialize the model and processor
model = AutoModelForCausalLM.from_pretrained(
    args.model_path, torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
lock = threading.Lock() # Use a lock to ensure thread safety when accessing the model
def predict_image(image, task: str = "<OD>", prompt: str = None):
    # Florence-2 prompts are a task token, optionally followed by free text
    prompt = task + " " + prompt if prompt else task
    print(f"Prompt: {prompt}")
    with lock:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3
        )
        # Keep special tokens: post_process_generation needs them to parse the task output
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )
    return parsed_answer
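# For the default "<OD>" task, the parsed answer typically has the shape
# below (per the Florence-2 model card; shown only as an illustration):
#   {'<OD>': {'bboxes': [[x1, y1, x2, y2], ...], 'labels': ['...', ...]}}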
@app.route('/predict', methods=['POST'])
def predict():
    if request.is_json:
        data = request.get_json()
        if 'image' not in data:
            return jsonify({'error': 'No image found in JSON'}), 400
        # Accept both raw base64 and data-URL payloads ("data:image/...;base64,...")
        image_data = base64.b64decode(data['image'].split(",")[-1])
        # Convert to RGB so grayscale/RGBA uploads don't break the processor
        image = Image.open(BytesIO(image_data)).convert("RGB")
    else:
        return jsonify({'error': 'No image file or JSON payload'}), 400
    task = data.get('task', "<OD>")
    prompt = data.get('prompt', None)
    prediction = predict_image(image, task, prompt)
    msgid = data.get('msgid', None)
    response = {
        'msgid': msgid,
        'prediction': prediction
    }
    return jsonify(response)
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, threaded=True)  # Enable multi-threading
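# Example client (a minimal sketch; "test.jpg", the localhost URL, and the
# data-URL prefix are illustrative, not part of this file):
#
#   import base64, requests
#
#   with open("test.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:5000/predict",
#       json={"image": "data:image/jpeg;base64," + b64, "task": "<OD>"},
#   )
#   print(resp.json())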