banao-tech commited on
Commit
d03c47c
·
verified ·
1 Parent(s): 24bf6bb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +46 -41
main.py CHANGED
@@ -1,62 +1,55 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
- from typing import Optional
5
  import base64
6
  import io
7
- from PIL import Image
8
- import torch
9
- import numpy as np
10
  import os
11
 
12
- # Existing imports
13
- import numpy as np
14
- import torch
15
  from PIL import Image
16
- import io
 
17
 
 
18
  from utils import (
19
  check_ocr_box,
20
  get_yolo_model,
21
  get_caption_model_processor,
22
  get_som_labeled_img,
23
  )
24
- import torch
25
-
26
- # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
- # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
 
 
29
  from ultralytics import YOLO
 
30
 
31
- # if not os.path.exists("/data/icon_detect"):
32
- # os.makedirs("/data/icon_detect")
 
33
  try:
34
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cuda", weights_only=False)["model"]
35
  yolo_model = yolo_model.to("cuda")
36
- except:
 
37
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cpu", weights_only=False)["model"]
38
 
39
-
40
  print(f"YOLO model type: {type(yolo_model)}")
41
 
42
- from transformers import AutoProcessor, AutoModelForCausalLM
43
- import torch
44
-
45
- # Check if CUDA is available
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
- dtype = torch.float16 if device == "cuda" else torch.float32 # Use float32 on CPU
48
- processor = AutoProcessor.from_pretrained(
49
- "microsoft/Florence-2-base", trust_remote_code=True
50
- )
51
 
52
  try:
53
  model = AutoModelForCausalLM.from_pretrained(
54
  "weights/icon_caption_florence",
55
- torch_dtype=dtype, # Dynamic dtype based on device
56
  trust_remote_code=True
57
  ).to(device)
58
  except Exception as e:
59
- print(f"Error loading model: {str(e)}")
60
  # Fallback to CPU with float32
61
  model = AutoModelForCausalLM.from_pretrained(
62
  "weights/icon_caption_florence",
@@ -64,32 +57,37 @@ except Exception as e:
64
  trust_remote_code=True
65
  ).to("cpu")
66
 
67
- # Force config for DaViT vision tower
68
  if not hasattr(model.config, 'vision_config'):
69
  model.config.vision_config = {}
70
  if 'model_type' not in model.config.vision_config:
71
  model.config.vision_config['model_type'] = 'davit'
72
 
73
  caption_model_processor = {"processor": processor, "model": model}
74
- print("finish loading model!!!")
75
 
 
 
 
76
  app = FastAPI()
77
 
78
-
79
  class ProcessResponse(BaseModel):
80
  image: str # Base64 encoded image
81
  parsed_content_list: str
82
  label_coordinates: str
83
 
84
-
85
- def process(
86
- image_input: Image.Image, box_threshold: float, iou_threshold: float
87
- ) -> ProcessResponse:
 
88
  image_save_path = "imgs/saved_image_demo.png"
89
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
90
  image_input.save(image_save_path)
 
 
91
  image = Image.open(image_save_path)
92
- box_overlay_ratio = image.size[0] / 3200
93
  draw_bbox_config = {
94
  "text_scale": 0.8 * box_overlay_ratio,
95
  "text_thickness": max(int(2 * box_overlay_ratio), 1),
@@ -97,6 +95,7 @@ def process(
97
  "thickness": max(int(3 * box_overlay_ratio), 1),
98
  }
99
 
 
100
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
101
  image_save_path,
102
  display_img=False,
@@ -106,6 +105,8 @@ def process(
106
  use_paddleocr=True,
107
  )
108
  text, ocr_bbox = ocr_bbox_rslt
 
 
109
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
110
  image_save_path,
111
  yolo_model,
@@ -117,22 +118,26 @@ def process(
117
  ocr_text=text,
118
  iou_threshold=iou_threshold,
119
  )
 
 
120
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
121
- print("finish processing")
122
  parsed_content_list_str = "\n".join(parsed_content_list)
123
 
124
- # Encode image to base64
125
  buffered = io.BytesIO()
126
  image.save(buffered, format="PNG")
127
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
128
 
129
  return ProcessResponse(
130
  image=img_str,
131
- parsed_content_list=str(parsed_content_list_str),
132
  label_coordinates=str(label_coordinates),
133
  )
134
 
135
-
 
 
136
  @app.post("/process_image", response_model=ProcessResponse)
137
  async def process_image(
138
  image_file: UploadFile = File(...),
@@ -143,7 +148,7 @@ async def process_image(
143
  contents = await image_file.read()
144
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
145
 
146
- # Add debug logging
147
  print(f"Processing image: {image_file.filename}")
148
  print(f"Image size: {image_input.size}")
149
 
@@ -157,5 +162,5 @@ async def process_image(
157
 
158
  except Exception as e:
159
  import traceback
160
- traceback.print_exc() # This will show full error in logs
161
- raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
 
4
  import base64
5
  import io
 
 
 
6
  import os
7
 
 
 
 
8
  from PIL import Image
9
+ import torch
10
+ import numpy as np
11
 
12
+ # Import your custom utility functions
13
  from utils import (
14
  check_ocr_box,
15
  get_yolo_model,
16
  get_caption_model_processor,
17
  get_som_labeled_img,
18
  )
 
 
 
 
19
 
20
+ # Import YOLO from ultralytics and transformers for captioning
21
  from ultralytics import YOLO
22
+ from transformers import AutoProcessor, AutoModelForCausalLM
23
 
24
+ # ---------------------------------------------------------------------------
25
+ # Load the YOLO model
26
+ # ---------------------------------------------------------------------------
27
  try:
28
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cuda", weights_only=False)["model"]
29
  yolo_model = yolo_model.to("cuda")
30
+ except Exception as e:
31
+ print("Error loading YOLO model on CUDA:", e)
32
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cpu", weights_only=False)["model"]
33
 
 
34
  print(f"YOLO model type: {type(yolo_model)}")
35
 
36
+ # ---------------------------------------------------------------------------
37
+ # Load the captioning model (Florence-2)
38
+ # ---------------------------------------------------------------------------
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ dtype = torch.float16 if device == "cuda" else torch.float32
41
+
42
+ # Load the processor for the Florence-2 model
43
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
44
 
45
  try:
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "weights/icon_caption_florence",
48
+ torch_dtype=dtype,
49
  trust_remote_code=True
50
  ).to(device)
51
  except Exception as e:
52
+ print(f"Error loading caption model: {str(e)}")
53
  # Fallback to CPU with float32
54
  model = AutoModelForCausalLM.from_pretrained(
55
  "weights/icon_caption_florence",
 
57
  trust_remote_code=True
58
  ).to("cpu")
59
 
60
+ # Force configuration for DaViT vision tower if missing
61
  if not hasattr(model.config, 'vision_config'):
62
  model.config.vision_config = {}
63
  if 'model_type' not in model.config.vision_config:
64
  model.config.vision_config['model_type'] = 'davit'
65
 
66
  caption_model_processor = {"processor": processor, "model": model}
67
+ print("Finish loading caption model!")
68
 
69
+ # ---------------------------------------------------------------------------
70
+ # Create FastAPI application and response model
71
+ # ---------------------------------------------------------------------------
72
  app = FastAPI()
73
 
 
74
  class ProcessResponse(BaseModel):
75
  image: str # Base64 encoded image
76
  parsed_content_list: str
77
  label_coordinates: str
78
 
79
+ # ---------------------------------------------------------------------------
80
+ # Main processing function
81
+ # ---------------------------------------------------------------------------
82
+ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
83
+ # Save the input image temporarily
84
  image_save_path = "imgs/saved_image_demo.png"
85
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
86
  image_input.save(image_save_path)
87
+
88
+ # Open the saved image for processing
89
  image = Image.open(image_save_path)
90
+ box_overlay_ratio = image.size[0] / 3200 # adjust scaling factor as needed
91
  draw_bbox_config = {
92
  "text_scale": 0.8 * box_overlay_ratio,
93
  "text_thickness": max(int(2 * box_overlay_ratio), 1),
 
95
  "thickness": max(int(3 * box_overlay_ratio), 1),
96
  }
97
 
98
+ # Run OCR to get text and OCR bounding boxes
99
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
100
  image_save_path,
101
  display_img=False,
 
105
  use_paddleocr=True,
106
  )
107
  text, ocr_bbox = ocr_bbox_rslt
108
+
109
+ # Run YOLO and semantic processing to get the labeled image and bounding boxes
110
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
111
  image_save_path,
112
  yolo_model,
 
118
  ocr_text=text,
119
  iou_threshold=iou_threshold,
120
  )
121
+
122
+ # Decode the base64-encoded image output from get_som_labeled_img
123
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
124
+ print("Finish processing")
125
  parsed_content_list_str = "\n".join(parsed_content_list)
126
 
127
+ # Encode final image to base64 string for response
128
  buffered = io.BytesIO()
129
  image.save(buffered, format="PNG")
130
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
131
 
132
  return ProcessResponse(
133
  image=img_str,
134
+ parsed_content_list=parsed_content_list_str,
135
  label_coordinates=str(label_coordinates),
136
  )
137
 
138
+ # ---------------------------------------------------------------------------
139
+ # FastAPI endpoint for image processing
140
+ # ---------------------------------------------------------------------------
141
  @app.post("/process_image", response_model=ProcessResponse)
142
  async def process_image(
143
  image_file: UploadFile = File(...),
 
148
  contents = await image_file.read()
149
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
150
 
151
+ # Debug logging for file information
152
  print(f"Processing image: {image_file.filename}")
153
  print(f"Image size: {image_input.size}")
154
 
 
162
 
163
  except Exception as e:
164
  import traceback
165
+ traceback.print_exc() # Print full traceback for debugging
166
+ raise HTTPException(status_code=500, detail=str(e))