atlury committed on
Commit
e21d024
·
verified ·
1 Parent(s): 113f0fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -27
app.py CHANGED
@@ -23,52 +23,54 @@ ENTITIES_COLORS = {
23
  }
24
  BOX_PADDING = 2
25
 
26
- # Load pre-trained YOLOv8 model
27
- model_path_1 = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
28
- model_path_2 = "models/dla-model.pt"
29
-
30
- if not os.path.exists(model_path_1):
31
- # Download the model file if it doesn't exist
32
- model_url_1 = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
33
- response = requests.get(model_url_1)
34
- with open(model_path_1, "wb") as f:
35
- f.write(response.content)
36
 
37
- if not os.path.exists(model_path_2):
38
- # Assume the second model file is manually uploaded in the specified path
39
- pass
 
 
 
 
 
 
40
 
41
  # Load models
42
- model_1 = YOLO(model_path_1)
43
- model_2 = YOLO(model_path_2)
44
 
45
- # Get class names from the first model
46
- class_names_1 = model_1.names
47
- class_names_2 = list(ENTITIES_COLORS.keys())
48
 
49
  @spaces.GPU(duration=60)
50
  def process_image(image, model_choice):
51
  try:
52
- if model_choice == "YOLOv8 Model":
53
- # Use the first model
54
- results = model_1(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
 
55
  result = results[0]
56
 
57
  # Extract annotated image and labels with class names
58
  annotated_image = result.plot()
59
 
60
  detected_areas_labels = "\n".join([
61
- f"{class_names_1[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
62
  ])
63
 
64
  return annotated_image, detected_areas_labels
65
 
66
  elif model_choice == "DLA Model":
67
- # Use the second model
68
  image_path = "input_image.jpg" # Temporarily save the uploaded image
69
  cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
70
  image = cv2.imread(image_path)
71
- results = model_2.predict(source=image, conf=0.2, iou=0.8)
72
  boxes = results[0].boxes
73
 
74
  if len(boxes) == 0:
@@ -76,7 +78,7 @@ def process_image(image, model_choice):
76
 
77
  for box in boxes:
78
  detection_class_conf = round(box.conf.item(), 2)
79
- cls = class_names_2[int(box.cls)]
80
  start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
81
  end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))
82
 
@@ -98,7 +100,7 @@ def process_image(image, model_choice):
98
  start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
99
  image = cv2.putText(img=image, text=text, org=start_text, fontFace=0, color=(255,255,255), fontScale=line_thickness/3, thickness=font_thickness)
100
 
101
- return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "Labels: " + ", ".join(class_names_2)
102
 
103
  else:
104
  return None, "Invalid model choice"
@@ -114,7 +116,7 @@ with gr.Blocks() as demo:
114
  input_image = gr.Image(type="pil", label="Upload Image")
115
  output_image = gr.Image(type="pil", label="Annotated Image")
116
 
117
- model_choice = gr.Dropdown(["YOLOv8 Model", "DLA Model"], label="Select Model", value="YOLOv8 Model", scale=0.5)
118
  output_text = gr.Textbox(label="Detected Areas and Labels")
119
 
120
  btn = gr.Button("Run Document Segmentation")
 
23
  }
24
  BOX_PADDING = 2
25
 
26
+ # Load pre-trained YOLOv8 models
27
+ model_paths = {
28
+ "YOLOv8x Model": "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt",
29
+ "YOLOv8m Model": "yolov8m-doclaynet.pt",
30
+ "YOLOv8n Model": "yolov8n-doclaynet.pt",
31
+ "DLA Model": "models/dla-model.pt"
32
+ }
 
 
 
33
 
34
+ # Ensure the model files are in the correct location
35
+ for model_name, model_path in model_paths.items():
36
+ if not os.path.exists(model_path):
37
+ # For demonstration, we only download the YOLOv8x model
38
+ if model_name == "YOLOv8x Model":
39
+ model_url = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
40
+ response = requests.get(model_url)
41
+ with open(model_path, "wb") as f:
42
+ f.write(response.content)
43
 
44
  # Load models
45
+ models = {name: YOLO(path) for name, path in model_paths.items()}
 
46
 
47
+ # Get class names from the YOLOv8 models
48
+ class_names = list(ENTITIES_COLORS.keys())
 
49
 
50
  @spaces.GPU(duration=60)
51
  def process_image(image, model_choice):
52
  try:
53
+ if "YOLOv8" in model_choice:
54
+ # Use the selected YOLOv8 model
55
+ model = models[model_choice]
56
+ results = model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
57
  result = results[0]
58
 
59
  # Extract annotated image and labels with class names
60
  annotated_image = result.plot()
61
 
62
  detected_areas_labels = "\n".join([
63
+ f"{class_names[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
64
  ])
65
 
66
  return annotated_image, detected_areas_labels
67
 
68
  elif model_choice == "DLA Model":
69
+ # Use the DLA model
70
  image_path = "input_image.jpg" # Temporarily save the uploaded image
71
  cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
72
  image = cv2.imread(image_path)
73
+ results = models[model_choice].predict(source=image, conf=0.2, iou=0.8)
74
  boxes = results[0].boxes
75
 
76
  if len(boxes) == 0:
 
78
 
79
  for box in boxes:
80
  detection_class_conf = round(box.conf.item(), 2)
81
+ cls = class_names[int(box.cls)]
82
  start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
83
  end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))
84
 
 
100
  start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
101
  image = cv2.putText(img=image, text=text, org=start_text, fontFace=0, color=(255,255,255), fontScale=line_thickness/3, thickness=font_thickness)
102
 
103
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "Labels: " + ", ".join(class_names)
104
 
105
  else:
106
  return None, "Invalid model choice"
 
116
  input_image = gr.Image(type="pil", label="Upload Image")
117
  output_image = gr.Image(type="pil", label="Annotated Image")
118
 
119
+ model_choice = gr.Dropdown(list(model_paths.keys()), label="Select Model", value="YOLOv8x Model", scale=0.5)
120
  output_text = gr.Textbox(label="Detected Areas and Labels")
121
 
122
  btn = gr.Button("Run Document Segmentation")