"""Flask inference server for a Florence-2 style vision-language model.

Loads the model and processor named by ``--model-path`` once at start-up and
serves ``POST /predict``, which accepts a JSON object of the form::

    {"image": "<base64 string, optionally prefixed 'data:...;base64,'>",
     "task": "<task token>", "prompt": "<optional text>",
     "msgid": <optional, echoed back in the response>}
"""
import argparse
import base64
import os
import threading
from io import BytesIO
from unittest.mock import patch

import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

app = Flask(__name__)

# Parse command line arguments
parser = argparse.ArgumentParser(description='Start the Flask server with specified model and device.')
parser.add_argument('--model-path', type=str, default="models/Florence-2-base-ft",
                    help='Path to the pretrained model')
# 'auto' must be one of the choices: argparse does not check the default
# against `choices`, so without it the documented "auto" value could never
# be passed explicitly on the command line.
parser.add_argument('--device', type=str, choices=['cpu', 'gpu', 'auto'], default='auto',
                    help='Device to use: "cpu", "gpu", or "auto"')
parser.add_argument('--host', type=str, default='0.0.0.0',
                    help='Interface to bind to (default 0.0.0.0 listens on all interfaces)')
parser.add_argument('--port', type=int, default=5000, help='Port to listen on')
args = parser.parse_args()


def _resolve_device(choice: str) -> str:
    """Map a ``--device`` choice to a torch device string.

    Raises:
        ValueError: if ``choice`` is ``'gpu'`` but CUDA is not available.
    """
    if choice == 'cpu':
        return "cpu"
    cuda_available = torch.cuda.is_available()
    if choice == 'gpu' and not cuda_available:
        raise ValueError("GPU option specified but no GPU is available.")
    # 'gpu' with CUDA present, or 'auto': prefer the first CUDA device.
    return "cuda:0" if cuda_available else "cpu"


device = _resolve_device(args.device)
# Half precision only on CUDA; the CPU path stays in float32.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports transformers finds in ``filename``, minus ``flash_attn``.

    Workaround: the model's remote code declares ``flash_attn`` as an import,
    so transformers would require it to be installed. It is not needed here
    because the model is loaded with ``attn_implementation="sdpa"``. The
    removal is guarded so that files which never import ``flash_attn`` pass
    through unchanged instead of raising ``ValueError``.
    """
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports


# Initialize the model and processor. The patch is the workaround for the
# unnecessary flash_attn requirement. NOTE: trust_remote_code=True executes
# Python code shipped with the model, so --model-path must be trusted.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        attn_implementation="sdpa",
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device)
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

# Serializes use of the shared model/processor across Flask's request threads
# (the server runs with threaded=True).
lock = threading.Lock()


def predict_image(image, task: str = "", prompt: str | None = None):
    """Run the model on ``image`` and return the processor's parsed answer.

    Args:
        image: PIL image to analyse.
        task: Task token passed to the processor and to post-processing.
        prompt: Optional text appended after the task token, separated by a space.

    Returns:
        The value of ``processor.post_process_generation`` for ``task``.
    """
    text = f"{task} {prompt}" if prompt else task
    print(f"Prompt: {text}")
    with lock:
        # .to(device, torch_dtype) moves the batch to the model's device and dtype.
        inputs = processor(text=text, images=image, return_tensors="pt").to(device, torch_dtype)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept because post-processing parses them.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        return processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )


def _decode_image(encoded: str) -> Image.Image:
    """Decode a base64 image, accepting an optional ``data:...;base64,`` prefix.

    Raises:
        ValueError: (including ``binascii.Error``) for malformed base64.
        OSError: (including ``PIL.UnidentifiedImageError``) if the bytes are not a readable image.
    """
    _, sep, tail = encoded.partition(",")
    payload = tail if sep else encoded
    raw = base64.b64decode(payload)
    # convert() forces the lazy Image.open to fully decode now, so corrupt data
    # fails here rather than inside the model call, and normalizes the mode.
    return Image.open(BytesIO(raw)).convert("RGB")


@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict``.

    Returns a JSON body ``{"msgid": ..., "prediction": ...}``. If the request
    is not JSON or the image is missing or undecodable, returns a JSON
    ``{"error": ...}`` body with status 400.
    """
    if not request.is_json:
        return jsonify({'error': 'No image file or JSON payload'}), 400
    # silent=True: malformed JSON yields None and a JSON error reply,
    # not Flask's default HTML 400 page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': 'No image found in JSON'}), 400
    encoded = data['image']
    if not isinstance(encoded, str):
        return jsonify({'error': 'Image must be a base64-encoded string'}), 400
    try:
        image = _decode_image(encoded)
    except (ValueError, OSError):
        return jsonify({'error': 'Could not decode image'}), 400

    task = data.get('task', "")
    prompt = data.get('prompt', None)
    prediction = predict_image(image, task, prompt)
    response = {
        'msgid': data.get('msgid', None),
        'prediction': prediction,
    }
    return jsonify(response)


if __name__ == '__main__':
    # threaded=True lets Flask serve requests concurrently; model access is
    # still serialized by `lock`.
    app.run(host=args.host, port=args.port, threaded=True)