from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging

from PIL import Image
import torch

# Existing imports
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load YOLO model using the official helper and move it to the available device
yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.to(device)

# Load caption model and processor
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
except Exception as e:
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    # float16 is poorly supported on CPU, so fall back to float32 there
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models")

app = FastAPI()


class ProcessResponse(BaseModel):
    image: str  # Base64-encoded annotated image
    parsed_content_list: str
    label_coordinates: str


def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    # Save the upload to disk because the OCR and labeling helpers expect a file path
    image_save_path = "imgs/saved_image_demo.png"
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)
    image = Image.open(image_save_path)

    # Scale annotation sizes relative to the image width
    box_overlay_ratio = image.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    # Run OCR to get detected text and bounding boxes
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    # Detect icons, merge with OCR boxes, and caption the detected elements
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    logger.info("Finished processing")
    parsed_content_list_str = "\n".join(parsed_content_list)

    # Re-encode the annotated image as base64 PNG for the JSON response
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    try:
        contents = await image_file.read()
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        logger.info(f"Processing image: {image_file.filename} (size: {image_input.size})")
        response = process(image_input, box_threshold, iou_threshold)
        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except Exception as e:
        logger.exception("Failed to process image")
        raise HTTPException(status_code=500, detail=str(e))
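

# --- Optional local entry point: a minimal sketch, not part of the original server code. ---
# Assumes uvicorn is installed and port 8000 is free; adjust host/port as needed.
# Example request once the server is running (hypothetical file name "screenshot.png"):
#   curl -X POST "http://localhost:8000/process_image?box_threshold=0.05&iou_threshold=0.1" \
#        -F "image_file=@screenshot.png"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)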