omniapi / main.py
banao-tech's picture
Update main.py
e197961 verified
raw
history blame
4.31 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
from PIL import Image
import torch
import numpy as np
import logging
# Existing imports
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM
# Module logger; records propagate to the root logger configured below.
logger = logging.getLogger(__name__)
# Root logging at INFO so this module's logger.info(...) messages are emitted.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level="INFO")
# main.py (YOLO loading fix)
from utils import get_yolo_model
import torch
# Load the icon-detection YOLO weights through the project helper.
yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")

# Choose the compute device once; `device` stays available as a module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the detector onto that device: .cuda() on GPU hosts, .cpu() otherwise.
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
# Load caption model and processor.
# The processor is loaded outside any try/except: there is no fallback for it,
# and a failure should surface here rather than as a NameError further down.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

_CAPTION_MODEL_PATH = "weights/icon_caption_florence"


def _load_caption_model():
    """Load the Florence caption model, preferring float16 on the GPU.

    Returns:
        The model on CUDA in float16 when CUDA is available and loading
        succeeds. Otherwise, the model on CPU in float32.
    """
    if torch.cuda.is_available():
        try:
            return AutoModelForCausalLM.from_pretrained(
                _CAPTION_MODEL_PATH,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to("cuda")
        except Exception as e:
            logger.warning(
                "Failed to load caption model on GPU: %s. Falling back to CPU.", e
            )
    else:
        # Check up front instead of loading the whole model only to fail on
        # .to("cuda") and then load it a second time.
        logger.warning("CUDA is not available. Loading caption model on CPU.")
    # float32 on CPU: many half-precision kernels are unavailable or slow on
    # CPU, so the previous float16 fallback was prone to runtime errors.
    return AutoModelForCausalLM.from_pretrained(
        _CAPTION_MODEL_PATH,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")

app = FastAPI()
class ProcessResponse(BaseModel):
    """Response body of the ``/process_image`` endpoint.

    Every field is transported as a plain string.
    """

    # Annotated image: PNG bytes encoded as base64 text.
    image: str
    # Parsed element descriptions, one entry per line.
    parsed_content_list: str
    # Label coordinates formatted with str() (not JSON-encoded).
    label_coordinates: str
def process(
    image_input: Image.Image,
    box_threshold: float,
    iou_threshold: float,
    *,
    image_save_path: str = "imgs/saved_image_demo.png",
) -> ProcessResponse:
    """Run OCR plus icon detection/captioning on an image and package the result.

    The ``utils`` helpers take a file path, so the image is first written to
    ``image_save_path``. That file is overwritten on every call.

    Args:
        image_input: Image to analyse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.
        image_save_path: Keyword-only scratch path for the intermediate image.
            Defaults to the previously hard-coded location.

    Returns:
        ProcessResponse holding:
        - the labelled image as base64-encoded PNG;
        - the parsed content entries joined by newlines;
        - ``str()`` of the label coordinates.
    """
    save_dir = os.path.dirname(image_save_path)
    if save_dir:  # os.makedirs("") raises, so skip for bare filenames
        os.makedirs(save_dir, exist_ok=True)
    image_input.save(image_save_path)

    # Scale annotation styling with image width (3200 px reference width).
    # The saved file has the same dimensions as image_input, so there is no
    # need to re-open it from disk. The previous re-open also left a file
    # handle unclosed.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    ocr_bbox_rslt, _is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )

    # The labelled image arrives base64-encoded. Decode it and re-encode it
    # as PNG so the response format is fixed regardless of the helper's
    # encoding.
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) as labeled_image:
        labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    logger.info("Finish processing")

    # str() each entry. It is a no-op for strings and avoids a TypeError if
    # the parser yields non-string items.
    parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded image and return the labelled image and parsed content.

    ``box_threshold`` and ``iou_threshold`` are forwarded unchanged to
    ``process``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 if processing fails or yields an empty image.
    """
    contents = await image_file.read()

    try:
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Unidentifiable/truncated uploads (PIL raises OSError subclasses) and
        # oversized images are client errors, not server faults.
        logger.warning("Rejected upload %r: %s", image_file.filename, e)
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a decodable image"
        ) from e

    logger.info("Processing image: %s", image_file.filename)
    logger.info("Image size: %s", image_input.size)

    # NOTE(review): process() is synchronous and runs directly on the event
    # loop, so one worker handles one image at a time. Moving it to a thread
    # pool would need the shared models/scratch file checked for concurrency.
    try:
        response = process(image_input, box_threshold, iou_threshold)
        if not response.image:
            raise ValueError("Empty image in response")
    except Exception as e:
        logger.exception("Failed to process image %r", image_file.filename)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return response