banao-tech committed on
Commit
ea2ade6
·
verified ·
1 Parent(s): d3c30f4

Rename app.py to main.py

Browse files
Files changed (1) hide show
  1. app.py → main.py +30 -24
app.py → main.py RENAMED
@@ -8,49 +8,55 @@ from PIL import Image
8
  import torch
9
  import numpy as np
10
 
11
- # Import your custom utility functions
12
  from utils import (
13
  check_ocr_box,
14
  get_yolo_model,
15
  get_caption_model_processor,
16
  get_som_labeled_img,
17
  )
18
-
19
- # Load the YOLO model using the ultralytics class instead of torch.load
20
  from ultralytics import YOLO
 
21
 
22
- # Use the YOLO constructor to load the model properly
23
- yolo_model = YOLO("weights/icon_detect/best.pt")
24
- print(f"YOLO model type: {type(yolo_model)}")
25
 
26
- # Load the captioning model (Florence-2)
27
- from transformers import AutoProcessor, AutoModelForCausalLM
28
 
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
- dtype = torch.float16 if device == "cuda" else torch.float32
 
 
 
 
31
 
32
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
 
 
 
 
 
 
 
33
  try:
 
 
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
  "weights/icon_caption_florence",
36
- torch_dtype=dtype,
37
- trust_remote_code=True
38
- ).to(device)
39
  except Exception as e:
40
- print(f"Error loading caption model: {str(e)}")
41
  model = AutoModelForCausalLM.from_pretrained(
42
  "weights/icon_caption_florence",
43
- torch_dtype=torch.float32,
44
- trust_remote_code=True
45
- ).to("cpu")
46
-
47
- if not hasattr(model.config, 'vision_config'):
48
- model.config.vision_config = {}
49
- if 'model_type' not in model.config.vision_config:
50
- model.config.vision_config['model_type'] = 'davit'
51
 
52
  caption_model_processor = {"processor": processor, "model": model}
53
- print("Finish loading caption model!")
54
 
55
  app = FastAPI()
56
 
 
8
  import torch
9
  import numpy as np
10
 
11
+ # Existing imports
12
  from utils import (
13
  check_ocr_box,
14
  get_yolo_model,
15
  get_caption_model_processor,
16
  get_som_labeled_img,
17
  )
 
 
18
  from ultralytics import YOLO
19
+ from transformers import AutoProcessor, AutoModelForCausalLM
20
 
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
 
 
 
25
 
26
+ # main.py (YOLO loading fix)
27
+ from utils import get_yolo_model
28
+ import torch
29
+
30
+ # Load YOLO model using official method
31
+ yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
32
 
33
+ # Handle device placement
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ if str(device) == "cuda":
36
+ yolo_model = yolo_model.cuda()
37
+ else:
38
+ yolo_model = yolo_model.cpu()
39
+
40
+ # Load caption model and processor
41
  try:
42
+ processor = AutoProcessor.from_pretrained(
43
+ "microsoft/Florence-2-base", trust_remote_code=True
44
+ )
45
  model = AutoModelForCausalLM.from_pretrained(
46
  "weights/icon_caption_florence",
47
+ torch_dtype=torch.float16,
48
+ trust_remote_code=True,
49
+ ).to("cuda")
50
  except Exception as e:
51
+ logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
52
  model = AutoModelForCausalLM.from_pretrained(
53
  "weights/icon_caption_florence",
54
+ torch_dtype=torch.float16,
55
+ trust_remote_code=True,
56
+ )
 
 
 
 
 
57
 
58
  caption_model_processor = {"processor": processor, "model": model}
59
+ logger.info("Finished loading models!!!")
60
 
61
  app = FastAPI()
62