Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Review DB management (#1)
Browse files- Review DB management (6c5b52b18656f4479762e1da53453110c56dcd18)
Co-authored-by: Lucain Pouget <Wauplin@users.noreply.huggingface.co>
app.py
CHANGED
@@ -7,8 +7,48 @@ import sqlite3
|
|
7 |
from datasets import load_dataset
|
8 |
import threading
|
9 |
import time
|
10 |
-
from
|
|
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the TTS Arena."
|
13 |
DESCR = """
|
14 |
# TTS Arena
|
@@ -25,11 +65,11 @@ INSTR = """
|
|
25 |
**When you're ready to begin, click the Start button below!** The model names will be revealed once you vote.
|
26 |
""".strip()
|
27 |
request = ''
|
28 |
-
if
|
29 |
request = f"""
|
30 |
### Request Model
|
31 |
|
32 |
-
Please fill out [this form](https://huggingface.co/spaces/{
|
33 |
"""
|
34 |
ABOUT = f"""
|
35 |
## About
|
@@ -57,28 +97,29 @@ A list of the models, based on how highly they are ranked!
|
|
57 |
""".strip()
|
58 |
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
65 |
def del_db(txt):
|
66 |
if not txt.lower() == 'delete db':
|
67 |
raise gr.Error('You did not enter "delete db"')
|
68 |
-
|
69 |
-
|
70 |
-
)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
repo_id=os.getenv('DATASET_ID'),
|
76 |
-
repo_type='dataset'
|
77 |
-
)
|
78 |
return 'Delete DB'
|
|
|
79 |
theme = gr.themes.Base(
|
80 |
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
|
81 |
)
|
|
|
82 |
model_names = {
|
83 |
'styletts2': 'StyleTTS 2',
|
84 |
'tacotron': 'Tacotron',
|
@@ -126,14 +167,15 @@ model_licenses = {
|
|
126 |
'speecht5': 'MIT',
|
127 |
}
|
128 |
# def get_random_split(existing_split=None):
|
129 |
-
# choice = random.choice(list(
|
130 |
# if existing_split and choice == existing_split:
|
131 |
# return get_random_split(choice)
|
132 |
# else:
|
133 |
# return choice
|
134 |
def get_db():
|
135 |
-
return sqlite3.connect(
|
136 |
-
|
|
|
137 |
conn = get_db()
|
138 |
cursor = conn.cursor()
|
139 |
cursor.execute('''
|
@@ -152,7 +194,7 @@ def create_db():
|
|
152 |
);
|
153 |
''')
|
154 |
|
155 |
-
def
|
156 |
conn = get_db()
|
157 |
cursor = conn.cursor()
|
158 |
cursor.execute('SELECT name, upvote, downvote FROM model WHERE (upvote + downvote) > 5')
|
@@ -193,10 +235,10 @@ def upvote_model(model, uname):
|
|
193 |
if cursor.rowcount == 0:
|
194 |
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
|
195 |
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
|
196 |
-
|
|
|
197 |
cursor.close()
|
198 |
|
199 |
-
|
200 |
def downvote_model(model, uname):
|
201 |
conn = get_db()
|
202 |
cursor = conn.cursor()
|
@@ -204,8 +246,10 @@ def downvote_model(model, uname):
|
|
204 |
if cursor.rowcount == 0:
|
205 |
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
|
206 |
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
|
207 |
-
|
|
|
208 |
cursor.close()
|
|
|
209 |
def a_is_better(model1, model2, profile: gr.OAuthProfile | None):
|
210 |
if not profile:
|
211 |
raise gr.Error(MUST_BE_LOGGEDIN)
|
@@ -236,8 +280,8 @@ def both_good(model1, model2, profile: gr.OAuthProfile | None):
|
|
236 |
return reload(model1, model2)
|
237 |
def reload(chosenmodel1=None, chosenmodel2=None):
|
238 |
# Select random splits
|
239 |
-
row = random.choice(list(
|
240 |
-
options = list(random.choice(list(
|
241 |
split1, split2 = random.sample(options, 2)
|
242 |
choice1, choice2 = (row[split1], row[split2])
|
243 |
if chosenmodel1 in model_names:
|
@@ -256,11 +300,11 @@ def reload(chosenmodel1=None, chosenmodel2=None):
|
|
256 |
|
257 |
with gr.Blocks() as leaderboard:
|
258 |
gr.Markdown(LDESC)
|
259 |
-
# df = gr.Dataframe(interactive=False, value=
|
260 |
df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 75, 50])
|
261 |
reloadbtn = gr.Button("Refresh")
|
262 |
-
leaderboard.load(
|
263 |
-
reloadbtn.click(
|
264 |
gr.Markdown("DISCLAIMER: The licenses listed may not be accurate or up to date, you are responsible for checking the licenses before using the models. Also note that some models may have additional usage restrictions.")
|
265 |
|
266 |
with gr.Blocks() as vote:
|
@@ -310,8 +354,8 @@ with gr.Blocks() as vote:
|
|
310 |
with gr.Blocks() as about:
|
311 |
gr.Markdown(ABOUT)
|
312 |
with gr.Blocks() as admin:
|
313 |
-
rdb = gr.Button("Reload Dataset")
|
314 |
-
rdb.click(
|
315 |
with gr.Group():
|
316 |
dbtext = gr.Textbox(label="Type \"delete db\" to confirm", placeholder="delete db")
|
317 |
ddb = gr.Button("Delete DB")
|
@@ -319,57 +363,5 @@ with gr.Blocks() as admin:
|
|
319 |
with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="TTS Leaderboard") as demo:
|
320 |
gr.Markdown(DESCR)
|
321 |
gr.TabbedInterface([vote, leaderboard, about, admin], ['Vote', 'Leaderboard', 'About', 'Admin (ONLY IN BETA)'])
|
322 |
-
|
323 |
-
api = HfApi(
|
324 |
-
token=os.getenv('HF_TOKEN')
|
325 |
-
)
|
326 |
-
time.sleep(60 * 60) # Every hour
|
327 |
-
print("Syncing DB before restarting space")
|
328 |
-
api.upload_file(
|
329 |
-
path_or_fileobj='database.db',
|
330 |
-
path_in_repo='database.db',
|
331 |
-
repo_id=os.getenv('DATASET_ID'),
|
332 |
-
repo_type='dataset'
|
333 |
-
)
|
334 |
-
print("Restarting space")
|
335 |
-
api.restart_space(repo_id=os.getenv('HF_ID'))
|
336 |
-
def sync_db():
|
337 |
-
api = HfApi(
|
338 |
-
token=os.getenv('HF_TOKEN')
|
339 |
-
)
|
340 |
-
while True:
|
341 |
-
time.sleep(60 * 10)
|
342 |
-
print("Uploading DB")
|
343 |
-
api.upload_file(
|
344 |
-
path_or_fileobj='database.db',
|
345 |
-
path_in_repo='database.db',
|
346 |
-
repo_id=os.getenv('DATASET_ID'),
|
347 |
-
repo_type='dataset'
|
348 |
-
)
|
349 |
-
if os.getenv('HF_ID'):
|
350 |
-
restart_thread = threading.Thread(target=restart_space)
|
351 |
-
restart_thread.daemon = True
|
352 |
-
restart_thread.start()
|
353 |
-
if os.getenv('DATASET_ID'):
|
354 |
-
# Fetch DB
|
355 |
-
api = HfApi(
|
356 |
-
token=os.getenv('HF_TOKEN')
|
357 |
-
)
|
358 |
-
print("Downloading DB...")
|
359 |
-
try:
|
360 |
-
path = api.hf_hub_download(
|
361 |
-
repo_id=os.getenv('DATASET_ID'),
|
362 |
-
repo_type='dataset',
|
363 |
-
filename='database.db',
|
364 |
-
cache_dir='./'
|
365 |
-
)
|
366 |
-
shutil.copyfile(path, 'database.db')
|
367 |
-
print("Downloaded DB")
|
368 |
-
except:
|
369 |
-
pass
|
370 |
-
# Update DB
|
371 |
-
db_thread = threading.Thread(target=sync_db)
|
372 |
-
db_thread.daemon = True
|
373 |
-
db_thread.start()
|
374 |
-
create_db()
|
375 |
demo.queue(api_open=False).launch(show_api=False)
|
|
|
7 |
from datasets import load_dataset
|
8 |
import threading
|
9 |
import time
|
10 |
+
from pathlib import Path
|
11 |
+
from huggingface_hub import CommitScheduler, delete_file, hf_hub_download
|
12 |
|
13 |
+
SPACE_ID = os.getenv('HF_ID')
|
14 |
+
|
15 |
+
DB_DATASET_ID = os.getenv('DATASET_ID')
|
16 |
+
DB_NAME = "database.db"
|
17 |
+
DB_PATH = "database.db"
|
18 |
+
|
19 |
+
AUDIO_DATASET_ID = "ttseval/tts-arena-new"
|
20 |
+
|
21 |
+
####################################
|
22 |
+
# Space initialization
|
23 |
+
####################################
|
24 |
+
|
25 |
+
# Download existing DB
|
26 |
+
print("Downloading DB...")
|
27 |
+
try:
|
28 |
+
cache_path = hf_hub_download(repo_id=DB_DATASET_ID, repo_type='dataset', filename=DB_NAME)
|
29 |
+
shutil.copyfile(cache_path, DB_PATH)
|
30 |
+
print("Downloaded DB")
|
31 |
+
except Exception as e:
|
32 |
+
print("Error while downloading DB:", e)
|
33 |
+
|
34 |
+
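# Create DB table (if it doesn't exist). NOTE(review): create_db_if_missing is defined further down in this module; confirm it is defined before this call runs.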
|
35 |
+
create_db_if_missing()
|
36 |
+
|
37 |
+
# Sync local DB with remote repo every 5 minutes (only if a change is detected)
|
38 |
+
scheduler = CommitScheduler(
|
39 |
+
repo_id=DB_DATASET_ID,
|
40 |
+
repo_type="dataset",
|
41 |
+
folder_path=Path(DB_PATH).parent,
|
42 |
+
every=5,
|
43 |
+
allow_patterns=DB_NAME,
|
44 |
+
)
|
45 |
+
|
46 |
+
# Load audio dataset
|
47 |
+
audio_dataset = load_dataset(AUDIO_DATASET_ID)
|
48 |
+
|
49 |
+
####################################
|
50 |
+
# Gradio app
|
51 |
+
####################################
|
52 |
MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the TTS Arena."
|
53 |
DESCR = """
|
54 |
# TTS Arena
|
|
|
65 |
**When you're ready to begin, click the Start button below!** The model names will be revealed once you vote.
|
66 |
""".strip()
|
67 |
request = ''
|
68 |
+
if SPACE_ID:
|
69 |
request = f"""
|
70 |
### Request Model
|
71 |
|
72 |
+
Please fill out [this form](https://huggingface.co/spaces/{SPACE_ID}/discussions/new?title=%5BModel+Request%5D+&description=%23%23%20Model%20Request%0A%0A%2A%2AModel%20website%2Fpaper%20%28if%20applicable%29%2A%2A%3A%0A%2A%2AModel%20available%20on%2A%2A%3A%20%28coqui%7CHF%20pipeline%7Ccustom%20code%29%0A%2A%2AWhy%20do%20you%20want%20this%20model%20added%3F%2A%2A%0A%2A%2AComments%3A%2A%2A) to request a model.
|
73 |
"""
|
74 |
ABOUT = f"""
|
75 |
## About
|
|
|
97 |
""".strip()
|
98 |
|
99 |
|
100 |
+
|
101 |
+
|
102 |
+
def reload_audio_dataset():
    """Reload the audio dataset and rebind the module-level ``audio_dataset``.

    Returns:
        The string 'Reload audio dataset'.
    """
    global audio_dataset
    # Load first, then rebind, so the global only changes once loading succeeded.
    refreshed = load_dataset(AUDIO_DATASET_ID)
    audio_dataset = refreshed
    return 'Reload audio dataset'
|
106 |
+
|
107 |
def del_db(txt):
    """Delete the vote database locally and in the dataset repo, then recreate it.

    Args:
        txt: Confirmation text; must equal "delete db" (case-insensitive).

    Returns:
        The string 'Delete DB'.

    Raises:
        gr.Error: If the confirmation text does not match.
    """
    if txt.lower() != 'delete db':
        raise gr.Error('You did not enter "delete db"')

    # Delete local + remote
    os.remove(DB_PATH)
    try:
        # The repo id constant in this module is DB_DATASET_ID.
        delete_file(path_in_repo=DB_NAME, repo_id=DB_DATASET_ID, repo_type='dataset')
    finally:
        # Recreate even if the remote delete fails, so the app is never left
        # without its tables after the local file has been removed.
        create_db_if_missing()
    return 'Delete DB'
|
118 |
+
|
119 |
theme = gr.themes.Base(
|
120 |
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
|
121 |
)
|
122 |
+
|
123 |
model_names = {
|
124 |
'styletts2': 'StyleTTS 2',
|
125 |
'tacotron': 'Tacotron',
|
|
|
167 |
'speecht5': 'MIT',
|
168 |
}
|
169 |
# def get_random_split(existing_split=None):
|
170 |
+
# choice = random.choice(list(audio_dataset.keys()))
|
171 |
# if existing_split and choice == existing_split:
|
172 |
# return get_random_split(choice)
|
173 |
# else:
|
174 |
# return choice
|
175 |
def get_db():
    """Open and return a new SQLite connection to the database at ``DB_PATH``."""
    connection = sqlite3.connect(DB_PATH)
    return connection
|
177 |
+
|
178 |
+
def create_db_if_missing():
|
179 |
conn = get_db()
|
180 |
cursor = conn.cursor()
|
181 |
cursor.execute('''
|
|
|
194 |
);
|
195 |
''')
|
196 |
|
197 |
+
def get_leaderboard():
|
198 |
conn = get_db()
|
199 |
cursor = conn.cursor()
|
200 |
cursor.execute('SELECT name, upvote, downvote FROM model WHERE (upvote + downvote) > 5')
|
|
|
235 |
if cursor.rowcount == 0:
|
236 |
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
|
237 |
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
|
238 |
+
with scheduler.lock:
|
239 |
+
conn.commit()
|
240 |
cursor.close()
|
241 |
|
|
|
242 |
def downvote_model(model, uname):
|
243 |
conn = get_db()
|
244 |
cursor = conn.cursor()
|
|
|
246 |
if cursor.rowcount == 0:
|
247 |
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
|
248 |
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
|
249 |
+
with scheduler.lock:
|
250 |
+
conn.commit()
|
251 |
cursor.close()
|
252 |
+
|
253 |
def a_is_better(model1, model2, profile: gr.OAuthProfile | None):
|
254 |
if not profile:
|
255 |
raise gr.Error(MUST_BE_LOGGEDIN)
|
|
|
280 |
return reload(model1, model2)
|
281 |
def reload(chosenmodel1=None, chosenmodel2=None):
|
282 |
# Select random splits
|
283 |
+
row = random.choice(list(audio_dataset['train']))
|
284 |
+
options = list(random.choice(list(audio_dataset['train'])).keys())
|
285 |
split1, split2 = random.sample(options, 2)
|
286 |
choice1, choice2 = (row[split1], row[split2])
|
287 |
if chosenmodel1 in model_names:
|
|
|
300 |
|
301 |
with gr.Blocks() as leaderboard:
|
302 |
gr.Markdown(LDESC)
|
303 |
+
# df = gr.Dataframe(interactive=False, value=get_leaderboard())
|
304 |
df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 75, 50])
|
305 |
reloadbtn = gr.Button("Refresh")
|
306 |
+
leaderboard.load(get_leaderboard, outputs=[df])
|
307 |
+
reloadbtn.click(get_leaderboard, outputs=[df])
|
308 |
gr.Markdown("DISCLAIMER: The licenses listed may not be accurate or up to date, you are responsible for checking the licenses before using the models. Also note that some models may have additional usage restrictions.")
|
309 |
|
310 |
with gr.Blocks() as vote:
|
|
|
354 |
with gr.Blocks() as about:
|
355 |
gr.Markdown(ABOUT)
|
356 |
with gr.Blocks() as admin:
|
357 |
+
rdb = gr.Button("Reload Audio Dataset")
|
358 |
+
rdb.click(reload_audio_dataset, outputs=rdb)
|
359 |
with gr.Group():
|
360 |
dbtext = gr.Textbox(label="Type \"delete db\" to confirm", placeholder="delete db")
|
361 |
ddb = gr.Button("Delete DB")
|
|
|
363 |
with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="TTS Leaderboard") as demo:
|
364 |
gr.Markdown(DESCR)
|
365 |
gr.TabbedInterface([vote, leaderboard, about, admin], ['Vote', 'Leaderboard', 'About', 'Admin (ONLY IN BETA)'])
|
366 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
demo.queue(api_open=False).launch(show_api=False)
|