"""Image-parsing API service: imports and model loading.

Loads, once at import time, the icon-detection checkpoint
(``weights/icon_detect/best.pt``) and the Florence-2 caption model
(``weights/icon_caption_florence``) together with the
``microsoft/Florence-2-base`` processor.
"""

import base64
import io
import os
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO  # noqa: F401 -- not referenced by name in this file

from utils import (
    check_ocr_box,
    get_caption_model_processor,
    get_som_labeled_img,
    get_yolo_model,
)

ICON_DETECT_WEIGHTS = "weights/icon_detect/best.pt"
ICON_CAPTION_WEIGHTS = "weights/icon_caption_florence"

# Pick the device once and use it for both models.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # float16 only on GPU


def _load_icon_detector(path: str, target_device: str):
    """Return the object stored under ``"model"`` in the checkpoint at *path*.

    The model is loaded onto *target_device*. If that fails on a non-CPU
    device, the load is retried on CPU.

    Raises:
        Whatever ``torch.load`` raises when the CPU load also fails.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (the
    # checkpoint holds a full model object, not just tensors). Only ever point
    # this at trusted, local weight files.
    try:
        detector = torch.load(path, map_location=target_device, weights_only=False)["model"]
        return detector.to(target_device)
    except Exception as exc:  # narrower than the bare except: lets KeyboardInterrupt through
        if target_device == "cpu":
            raise
        print(f"Error loading icon detector on {target_device}: {exc}; retrying on CPU")
        return torch.load(path, map_location="cpu", weights_only=False)["model"]


yolo_model = _load_icon_detector(ICON_DETECT_WEIGHTS, device)

processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        ICON_CAPTION_WEIGHTS,
        torch_dtype=dtype,  # float16 on GPU, float32 on CPU
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    print(f"Error loading model: {str(e)}")
    # Fallback: CPU with float32.
    model = AutoModelForCausalLM.from_pretrained(
        ICON_CAPTION_WEIGHTS,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to("cpu")


def _ensure_davit_vision_config(config) -> None:
    """Ensure ``config.vision_config`` declares ``model_type = "davit"``.

    If the attribute is missing or None, it is set to an empty dict first.
    Both a plain dict and an attribute-style config object are handled.
    Which one the remote Florence-2 code produces is not visible here.
    """
    vision_config = getattr(config, "vision_config", None)
    if vision_config is None:
        vision_config = {}
        config.vision_config = vision_config
    if isinstance(vision_config, dict):
        vision_config.setdefault("model_type", "davit")
    elif not getattr(vision_config, "model_type", None):
        vision_config.model_type = "davit"


# Force config for the DaViT vision tower.
_ensure_davit_vision_config(model.config)
import contextlib
import tempfile
import threading
import traceback

from fastapi.concurrency import run_in_threadpool

caption_model_processor = {"processor": processor, "model": model}
print("finish loading model!!!")

app = FastAPI()

# Directory for the per-request image files that the OCR/labeling helpers read from disk.
IMAGE_WORK_DIR = "imgs"

# Serializes model inference. The models and the OCR backend are shared
# module-level objects whose thread-safety is not established here, so only
# one request runs them at a time.
_inference_lock = threading.Lock()


class ProcessResponse(BaseModel):
    """Response body of ``POST /process_image``."""

    image: str  # Base64-encoded PNG of the labeled image
    parsed_content_list: str  # Parsed content entries joined with newlines
    label_coordinates: str  # str() of the label coordinates from get_som_labeled_img


def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR plus icon detection/captioning on *image_input*.

    Args:
        image_input: Image to parse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ProcessResponse containing:
        - the labeled image as base64 PNG;
        - the parsed content entries joined by newlines;
        - ``str()`` of the label coordinates.
    """
    os.makedirs(IMAGE_WORK_DIR, exist_ok=True)
    # check_ocr_box / get_som_labeled_img are given a file path, so the image is
    # written to disk first. A unique name per call stops concurrent requests,
    # or several worker processes sharing this directory, from overwriting
    # each other's input.
    fd, image_save_path = tempfile.mkstemp(suffix=".png", dir=IMAGE_WORK_DIR)
    os.close(fd)
    try:
        image_input.save(image_save_path)

        # Overlay drawing parameters scale with image width (ratio 1.0 at
        # 3200 px). The integer sizes never go below 1.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        ocr_bbox_rslt, _is_goal_filtered = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt
        labeled_img_b64, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    finally:
        # Best-effort cleanup: a leftover temp file must not turn a good result into an error.
        with contextlib.suppress(OSError):
            os.remove(image_save_path)

    # The labeled image comes back base64-encoded. Decode it and re-encode it
    # so the response always carries PNG.
    labeled_image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
    print("finish processing")
    buffered = io.BytesIO()
    labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return ProcessResponse(
        image=img_str,
        # str() each entry so non-string entries cannot make join() raise TypeError.
        parsed_content_list="\n".join(str(item) for item in parsed_content_list),
        label_coordinates=str(label_coordinates),
    )


def _process_serialized(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run :func:`process` while holding the inference lock (one inference at a time)."""
    with _inference_lock:
        return process(image_input, box_threshold, iou_threshold)


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded image and return the labeled image and parsed content.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 if processing fails.
    """
    try:
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # PIL could not decode the upload: a bad request, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

        # Add debug logging
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")

        # Inference blocks (model forward passes, OCR, disk I/O). Run it in a
        # worker thread so the event loop keeps serving other requests.
        response = await run_in_threadpool(
            _process_serialized, image_input, box_threshold, iou_threshold
        )

        # Validate response
        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()  # This will show full error in logs
        raise HTTPException(status_code=500, detail=str(e)) from e