tdurbor commited on
Commit
97067cd
1 Parent(s): b58cdd5

Cleanup app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -44
app.py CHANGED
@@ -5,13 +5,10 @@ import random
5
  import logging
6
  import threading
7
  from pathlib import Path
8
- from uuid import uuid4
9
- from typing import Tuple
10
  from datetime import datetime, timedelta
11
 
12
  import numpy as np
13
  import gradio as gr
14
- from PIL import Image
15
  from dotenv import load_dotenv
16
  from datasets import load_dataset
17
  from huggingface_hub import CommitScheduler
@@ -24,32 +21,29 @@ from db import (
24
  fill_database_once
25
  )
26
 
 
 
27
  token = os.getenv("HUGGINGFACE_HUB_TOKEN")
28
 
29
- # Load datasets
30
- dataset = load_dataset("bgsys/background-removal-arena-green", split='train')
31
- fill_database_once()
32
-
33
  # Configure logging
34
  logging.basicConfig(level=logging.INFO)
35
 
36
- # Load environment variables from .env file
37
- load_dotenv()
 
38
 
39
- # Directory and path setup for JSON dataset
40
  JSON_DATASET_DIR = Path("data/json_dataset")
41
  JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
42
 
43
- # Initialize CommitScheduler for Hugging Face only if running in space
44
- scheduler = None
45
- if is_running_in_space():
46
- scheduler = CommitScheduler(
47
- repo_id="bgsys/votes_datasets_test2",
48
- repo_type="dataset",
49
- folder_path=JSON_DATASET_DIR,
50
- path_in_repo="data",
51
- token=token
52
- )
53
 
54
  def fetch_elo_scores():
55
  """Fetch and log Elo scores."""
@@ -63,21 +57,15 @@ def fetch_elo_scores():
63
 
64
  def update_rankings_table():
65
  """Update and return the rankings table based on Elo scores."""
66
- elo_scores = fetch_elo_scores()
67
- if elo_scores:
68
- rankings = [
69
- ["Photoroom", int(elo_scores.get("Photoroom", 1000))],
70
- ["RemoveBG", int(elo_scores.get("RemoveBG", 1000))],
71
- ["BRIA RMBG 2.0", int(elo_scores.get("BRIA RMBG 2.0", 1000))],
72
- ]
73
- rankings.sort(key=lambda x: x[1], reverse=True)
74
- return rankings
75
- else:
76
- return [
77
- ["Photoroom", -1],
78
- ["RemoveBG", -1],
79
- ["BRIA RMBG 2.0", -1],
80
- ]
81
 
82
  def select_new_image():
83
  """Select a new image and its segmented versions."""
@@ -95,8 +83,7 @@ def select_new_image():
95
  sample = dataset[random_index]
96
  input_image = sample['original_image']
97
 
98
- segmented_images = [sample['clipdrop_image'], sample['bria_image'],
99
- sample['photoroom_image'], sample['removebg_image']]
100
  segmented_sources = ['Clipdrop', 'BRIA RMBG 2.0', 'Photoroom', 'RemoveBG']
101
 
102
  if segmented_images.count(None) > 2:
@@ -107,11 +94,11 @@ def select_new_image():
107
  try:
108
  selected_indices = random.sample([i for i, img in enumerate(segmented_images) if img is not None], 2)
109
  model_a_index, model_b_index = selected_indices
110
- model_a_output_image = segmented_images[model_a_index]
111
- model_b_output_image = segmented_images[model_b_index]
112
- model_a_name = segmented_sources[model_a_index]
113
- model_b_name = segmented_sources[model_b_index]
114
- return sample['original_filename'], input_image, model_a_output_image, model_b_output_image, model_a_name, model_b_name
115
  except Exception as e:
116
  logging.error("Error processing images: %s. Resampling another image.", str(e))
117
  last_image_index = random_index
@@ -164,10 +151,10 @@ def gradio_interface():
164
  with gr.Tab("⚔️ Arena (battle)", id=0):
165
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
166
 
167
- filname, input_image, segmented_a, segmented_b, a_name, b_name = select_new_image()
168
  model_a_name = gr.State(a_name)
169
  model_b_name = gr.State(b_name)
170
- fpath_input = gr.State(filname)
171
 
172
  # Compute the absolute difference between the masks
173
  mask_difference = compute_mask_difference(segmented_a, segmented_b)
 
5
  import logging
6
  import threading
7
  from pathlib import Path
 
 
8
  from datetime import datetime, timedelta
9
 
10
  import numpy as np
11
  import gradio as gr
 
12
  from dotenv import load_dotenv
13
  from datasets import load_dataset
14
  from huggingface_hub import CommitScheduler
 
21
  fill_database_once
22
  )
23
 
24
+ # Load environment variables
25
+ load_dotenv()
26
  token = os.getenv("HUGGINGFACE_HUB_TOKEN")
27
 
 
 
 
 
28
  # Configure logging
29
  logging.basicConfig(level=logging.INFO)
30
 
31
+ # Load datasets and initialize database
32
+ dataset = load_dataset("bgsys/background-removal-arena-green", split='train')
33
+ fill_database_once()
34
 
35
+ # Directory setup for JSON dataset
36
  JSON_DATASET_DIR = Path("data/json_dataset")
37
  JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
38
 
39
+ # Initialize CommitScheduler if running in space
40
+ scheduler = CommitScheduler(
41
+ repo_id="bgsys/votes_datasets_test2",
42
+ repo_type="dataset",
43
+ folder_path=JSON_DATASET_DIR,
44
+ path_in_repo="data",
45
+ token=token
46
+ ) if is_running_in_space() else None
 
 
47
 
48
  def fetch_elo_scores():
49
  """Fetch and log Elo scores."""
 
57
 
58
  def update_rankings_table():
59
  """Update and return the rankings table based on Elo scores."""
60
+ elo_scores = fetch_elo_scores() or {}
61
+ default_score = 1000
62
+ rankings = [
63
+ ["Photoroom", int(elo_scores.get("Photoroom", default_score))],
64
+ ["RemoveBG", int(elo_scores.get("RemoveBG", default_score))],
65
+ ["BRIA RMBG 2.0", int(elo_scores.get("BRIA RMBG 2.0", default_score))],
66
+ ]
67
+ rankings.sort(key=lambda x: x[1], reverse=True)
68
+ return rankings
 
 
 
 
 
 
69
 
70
  def select_new_image():
71
  """Select a new image and its segmented versions."""
 
83
  sample = dataset[random_index]
84
  input_image = sample['original_image']
85
 
86
+ segmented_images = [sample.get(key) for key in ['clipdrop_image', 'bria_image', 'photoroom_image', 'removebg_image']]
 
87
  segmented_sources = ['Clipdrop', 'BRIA RMBG 2.0', 'Photoroom', 'RemoveBG']
88
 
89
  if segmented_images.count(None) > 2:
 
94
  try:
95
  selected_indices = random.sample([i for i, img in enumerate(segmented_images) if img is not None], 2)
96
  model_a_index, model_b_index = selected_indices
97
+ return (
98
+ sample['original_filename'], input_image,
99
+ segmented_images[model_a_index], segmented_images[model_b_index],
100
+ segmented_sources[model_a_index], segmented_sources[model_b_index]
101
+ )
102
  except Exception as e:
103
  logging.error("Error processing images: %s. Resampling another image.", str(e))
104
  last_image_index = random_index
 
151
  with gr.Tab("⚔️ Arena (battle)", id=0):
152
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
153
 
154
+ filename, input_image, segmented_a, segmented_b, a_name, b_name = select_new_image()
155
  model_a_name = gr.State(a_name)
156
  model_b_name = gr.State(b_name)
157
+ fpath_input = gr.State(filename)
158
 
159
  # Compute the absolute difference between the masks
160
  mask_difference = compute_mask_difference(segmented_a, segmented_b)