Spaces:
Running
on
T4
Running
on
T4
File size: 4,313 Bytes
a1c932a d03c47c e197961 a1c932a ea2ade6 a1c932a ea2ade6 a1c932a ea2ade6 3dc870c 712f1db ea2ade6 d03c47c ea2ade6 beccd45 ea2ade6 beccd45 ea2ade6 29f706c ea2ade6 beccd45 ea2ade6 29f706c a1c932a ea2ade6 a1c932a d03c47c a1c932a beccd45 a1c932a d03c47c a1c932a 712f1db a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a beccd45 a1c932a beccd45 712f1db d03c47c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
from PIL import Image
import torch
import numpy as np
import logging
# Existing imports
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# main.py (YOLO loading fix)
from utils import get_yolo_model
import torch
# Load YOLO model using official method
yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if str(device) == "cuda":
yolo_model = yolo_model.cuda()
else:
yolo_model = yolo_model.cpu()
# Load caption model and processor
try:
processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
"weights/icon_caption_florence",
torch_dtype=torch.float16,
trust_remote_code=True,
).to("cuda")
except Exception as e:
logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
model = AutoModelForCausalLM.from_pretrained(
"weights/icon_caption_florence",
torch_dtype=torch.float16,
trust_remote_code=True,
)
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
app = FastAPI()
class ProcessResponse(BaseModel):
image: str # Base64 encoded image
parsed_content_list: str
label_coordinates: str
def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
image_save_path = "imgs/saved_image_demo.png"
os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
image_input.save(image_save_path)
image = Image.open(image_save_path)
box_overlay_ratio = image.size[0] / 3200
draw_bbox_config = {
"text_scale": 0.8 * box_overlay_ratio,
"text_thickness": max(int(2 * box_overlay_ratio), 1),
"text_padding": max(int(3 * box_overlay_ratio), 1),
"thickness": max(int(3 * box_overlay_ratio), 1),
}
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
image_save_path,
display_img=False,
output_bb_format="xyxy",
goal_filtering=None,
easyocr_args={"paragraph": False, "text_threshold": 0.9},
use_paddleocr=True,
)
text, ocr_bbox = ocr_bbox_rslt
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
image_save_path,
yolo_model,
BOX_TRESHOLD=box_threshold,
output_coord_in_ratio=True,
ocr_bbox=ocr_bbox,
draw_bbox_config=draw_bbox_config,
caption_model_processor=caption_model_processor,
ocr_text=text,
iou_threshold=iou_threshold,
)
image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
print("Finish processing")
parsed_content_list_str = "\n".join(parsed_content_list)
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return ProcessResponse(
image=img_str,
parsed_content_list=parsed_content_list_str,
label_coordinates=str(label_coordinates),
)
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
image_file: UploadFile = File(...),
box_threshold: float = 0.05,
iou_threshold: float = 0.1,
):
try:
contents = await image_file.read()
image_input = Image.open(io.BytesIO(contents)).convert("RGB")
print(f"Processing image: {image_file.filename}")
print(f"Image size: {image_input.size}")
response = process(image_input, box_threshold, iou_threshold)
if not response.image:
raise ValueError("Empty image in response")
return response
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
|