from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model
import torch
from ultralytics import YOLO
from PIL import Image
from typing import Dict, Tuple, List
import io
import base64
import time

# Settings consumed by Omniparser: model locations, target device, the
# drawing options forwarded to get_som_labeled_img, and the detection
# box threshold.
config = {
    'som_model_path': 'weights/icon_detect/best.pt',
    'device': 'cuda',  # Use 'cpu' if CUDA is unavailable
    'caption_model_path': 'weights/icon_caption_florence',
    'draw_bbox_config': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
    'BOX_TRESHOLD': 0.05
}


class Omniparser(object):
    """Bundle the detection model and caption model used to parse a screenshot.

    The models are loaded once in ``__init__`` and reused by every call to
    ``parse``.
    """

    def __init__(self, config: Dict):
        """Load the models described by *config*.

        Args:
            config: Mapping that provides at least ``'som_model_path'``,
                ``'caption_model_path'`` and ``'device'``. ``parse`` also
                reads ``'draw_bbox_config'`` and ``'BOX_TRESHOLD'`` from it.
        """
        self.config = config
        # Detection model, moved onto the configured device.
        self.som_model = get_yolo_model(
            model_path=config['som_model_path'],
        ).to(config['device'])
        # Caption model/processor bundle, loaded under the "florence2" name.
        self.caption_model_processor = get_caption_model_processor(
            model_name="florence2",
            model_name_or_path=config['caption_model_path'],
            device=config['device'],
        )

    def parse(self, image_path: str) -> Tuple[Image.Image, str, str]:
        """Run OCR and labeling on the image at *image_path*.

        Args:
            image_path: Path of the image, passed unchanged to
                ``check_ocr_box`` and ``get_som_labeled_img``.

        Returns:
            A 3-tuple ``(image, parsed_content, label_coordinates)``:
            the labeled image decoded into a PIL image, the parsed content
            entries joined with newlines, and ``str()`` of the label
            coordinates returned by ``get_som_labeled_img``.
        """
        print('Parsing image:', image_path)

        # The second element of the result (goal-filter flag) is not used.
        ocr_bbox_rslt, _ = check_ocr_box(
            image_path,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        )
        text, ocr_bbox = ocr_bbox_rslt

        draw_bbox_config = self.config['draw_bbox_config']
        BOX_TRESHOLD = self.config['BOX_TRESHOLD']
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_path,
            self.som_model,
            BOX_TRESHOLD=BOX_TRESHOLD,
            output_coord_in_ratio=False,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=self.caption_model_processor,
            ocr_text=text,
            use_local_semantics=True,
            iou_threshold=0.1,
        )

        # The labeled image arrives base64-encoded; decode it into a PIL image.
        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))

        # Format the output: one parsed entry per line.
        parsed_content_list = '\n'.join(parsed_content_list)
        return image, str(parsed_content_list), str(label_coordinates)


parser = Omniparser(config)
image_path = 'examples/pc_1.png'

# Time the parser.
s = time.perf_counter()
# parse() returns three values; unpacking into only two names raised
# "ValueError: too many values to unpack".
image, parsed_content_list, label_coordinates = parser.parse(image_path)
device = config['device']
print(f'Time taken for Omniparser on {device}:', time.perf_counter() - s)