Spaces:
Runtime error
Runtime error
zetavg
committed on
Commit
·
db1ee85
1
Parent(s):
469dc51
fix "operation not supported" while flagging and use different log files for different models
Browse files- llama_lora/lib/csv_logger.py +96 -0
- llama_lora/ui/inference_ui.py +20 -4
llama_lora/lib/csv_logger.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gradio import FlaggingCallback, utils
|
2 |
+
import csv
|
3 |
+
import datetime
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import secrets
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any, List
|
9 |
+
|
10 |
+
class CSVLogger(FlaggingCallback):
    """
    A FlaggingCallback that logs each flagged sample to a CSV file.

    Each flagged sample (one value per configured component, followed by the
    flag option, the username and a timestamp) is appended as one row to a
    CSV file inside the flagging directory. A header row is written when the
    file is first created.

    Unlike a single fixed log file, ``flag`` accepts a ``filename`` argument
    so that callers can route samples to different log files.

    Example:
        logger = CSVLogger()
        logger.setup(components, "flagged")
        logger.flag(values, flag_option="Flag", filename="log-my-model.csv")
    """

    def __init__(self):
        pass

    @staticmethod
    def _component_name(component: Any, idx: int) -> str:
        """Return the component's label, or ``component {idx}`` if it has none."""
        return getattr(component, "label", None) or f"component {idx}"

    def setup(
        self,
        components: List[Any],
        flagging_dir: str | Path,
    ):
        """
        Remember the components to log and make sure the log directory exists.

        Parameters:
            components: The components whose values are passed to ``flag``,
                in the same order as the values.
            flagging_dir: Directory in which the CSV log files are stored.
                It is created if it does not exist.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str = "",
        username: str | None = None,
        filename="log.csv",
    ) -> int:
        """
        Append one flagged sample to the CSV log and return the row count.

        Parameters:
            flag_data: One value per component given to ``setup``.
            flag_option: Text stored in the "flag" column.
            username: Text stored in the "username" column; ``None`` is
                stored as an empty string.
            filename: Name of the log file inside the flagging directory.
                Path separators, control characters and other characters
                that are unsafe in file names are replaced with "-".

        Returns:
            The number of rows in the log file after writing, not counting
            the header row.

        Raises:
            OSError: If the log file can be neither appended to nor rewritten.
        """
        flagging_dir = self.flagging_dir
        filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename)
        log_filepath = Path(flagging_dir) / filename
        is_new = not log_filepath.exists()

        headers = [
            self._component_name(component, idx)
            for idx, component in enumerate(self.components)
        ] + [
            "flag",
            "username",
            "timestamp",
        ]

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            save_dir = Path(flagging_dir) / self._component_name(component, idx)
            if utils.is_update(sample):
                csv_data.append(str(sample))
            elif sample is None:
                csv_data.append("")
            else:
                csv_data.append(component.deserialize(sample, save_dir=save_dir))
        csv_data.append(flag_option)
        csv_data.append(username if username is not None else "")
        csv_data.append(str(datetime.datetime.now()))

        # Rows to add: the header only when the log file is being created.
        rows = []
        if is_new:
            rows.append(utils.sanitize_list_for_csv(headers))
        rows.append(utils.sanitize_list_for_csv(csv_data))

        try:
            with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                csv.writer(csvfile).writerows(rows)
        except OSError:
            # Workaround for "OSError: [Errno 95] Operation not supported"
            # when appending to a file on some cloud-mounted directories:
            # write the old content plus the new rows to a fresh file and
            # move it over the log. This is done with file APIs rather than
            # shell commands, so no part of the path is ever interpreted by
            # a shell (the sanitizer above does not strip quotes).
            tmp_log_filepath = f"{log_filepath}.tmp_{secrets.token_hex(16)}"
            try:
                with open(
                    tmp_log_filepath, "w", newline="", encoding="utf-8"
                ) as tmpfile:
                    if not is_new:
                        with open(
                            log_filepath, "r", newline="", encoding="utf-8"
                        ) as oldfile:
                            tmpfile.write(oldfile.read())
                    csv.writer(tmpfile).writerows(rows)
                os.replace(tmp_log_filepath, log_filepath)
            finally:
                # After a successful replace the temp file is gone; this only
                # cleans up when the rewrite failed part-way.
                if os.path.exists(tmp_log_filepath):
                    os.remove(tmp_log_filepath)

        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            # Subtract one for the header row.
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        return line_count
|
llama_lora/ui/inference_ui.py
CHANGED
@@ -10,6 +10,7 @@ from transformers import GenerationConfig
|
|
10 |
from ..globals import Global
|
11 |
from ..models import get_model, get_tokenizer, get_device
|
12 |
from ..lib.inference import generate
|
|
|
13 |
from ..utils.data import (
|
14 |
get_available_template_names,
|
15 |
get_available_lora_model_names,
|
@@ -320,7 +321,7 @@ def inference_ui():
|
|
320 |
if not os.path.exists(flagging_dir):
|
321 |
os.makedirs(flagging_dir)
|
322 |
|
323 |
-
flag_callback =
|
324 |
flag_components = [
|
325 |
LoggingItem("Base Model"),
|
326 |
LoggingItem("Adaptor Model"),
|
@@ -366,6 +367,18 @@ def inference_ui():
|
|
366 |
json.dumps(output_for_flagging.get("generation_config", "")),
|
367 |
]
|
368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
things_that_might_timeout = []
|
370 |
|
371 |
with gr.Blocks() as inference_ui_blocks:
|
@@ -510,7 +523,8 @@ def inference_ui():
|
|
510 |
lambda d: (flag_callback.flag(
|
511 |
get_flag_callback_args(d, "Flag"),
|
512 |
flag_option="Flag",
|
513 |
-
username=None
|
|
|
514 |
), "")[1],
|
515 |
inputs=[output_for_flagging],
|
516 |
outputs=[flag_output],
|
@@ -519,7 +533,8 @@ def inference_ui():
|
|
519 |
lambda d: (flag_callback.flag(
|
520 |
get_flag_callback_args(d, "π"),
|
521 |
flag_option="Up Vote",
|
522 |
-
username=None
|
|
|
523 |
), "")[1],
|
524 |
inputs=[output_for_flagging],
|
525 |
outputs=[flag_output],
|
@@ -528,7 +543,8 @@ def inference_ui():
|
|
528 |
lambda d: (flag_callback.flag(
|
529 |
get_flag_callback_args(d, "π"),
|
530 |
flag_option="Down Vote",
|
531 |
-
username=None
|
|
|
532 |
), "")[1],
|
533 |
inputs=[output_for_flagging],
|
534 |
outputs=[flag_output],
|
|
|
10 |
from ..globals import Global
|
11 |
from ..models import get_model, get_tokenizer, get_device
|
12 |
from ..lib.inference import generate
|
13 |
+
from ..lib.csv_logger import CSVLogger
|
14 |
from ..utils.data import (
|
15 |
get_available_template_names,
|
16 |
get_available_lora_model_names,
|
|
|
321 |
if not os.path.exists(flagging_dir):
|
322 |
os.makedirs(flagging_dir)
|
323 |
|
324 |
+
flag_callback = CSVLogger()
|
325 |
flag_components = [
|
326 |
LoggingItem("Base Model"),
|
327 |
LoggingItem("Adaptor Model"),
|
|
|
367 |
json.dumps(output_for_flagging.get("generation_config", "")),
|
368 |
]
|
369 |
|
370 |
+
def get_flag_filename(output_for_flagging_str):
    """
    Choose the CSV log file name for a flagged sample based on its models.

    Parameters:
        output_for_flagging_str: JSON text of an object that may contain
            the keys "base_model" and "adaptor_model".

    Returns:
        "log.csv" when no base model is given (or the input is not a JSON
        object), "log-{base_model}.csv" when only a base model is given, and
        "log-{base_model}#{adaptor_model}.csv" when both are given.
    """
    default_filename = "log.csv"

    # Choosing a file name should never abort the flag action itself, so
    # unparsable or unexpected input falls back to the default log file.
    try:
        output_for_flagging = json.loads(output_for_flagging_str)
    except (TypeError, ValueError):
        return default_filename
    if not isinstance(output_for_flagging, dict):
        return default_filename

    base_model = output_for_flagging.get("base_model", None)
    adaptor_model = output_for_flagging.get("adaptor_model", None)
    # The literal string "None" is treated the same as no adaptor model.
    if adaptor_model == "None":
        adaptor_model = None

    if not base_model:
        return default_filename
    if not adaptor_model:
        return f"log-{base_model}.csv"
    return f"log-{base_model}#{adaptor_model}.csv"
|
381 |
+
|
382 |
things_that_might_timeout = []
|
383 |
|
384 |
with gr.Blocks() as inference_ui_blocks:
|
|
|
523 |
lambda d: (flag_callback.flag(
|
524 |
get_flag_callback_args(d, "Flag"),
|
525 |
flag_option="Flag",
|
526 |
+
username=None,
|
527 |
+
filename=get_flag_filename(d)
|
528 |
), "")[1],
|
529 |
inputs=[output_for_flagging],
|
530 |
outputs=[flag_output],
|
|
|
533 |
lambda d: (flag_callback.flag(
|
534 |
get_flag_callback_args(d, "π"),
|
535 |
flag_option="Up Vote",
|
536 |
+
username=None,
|
537 |
+
filename=get_flag_filename(d)
|
538 |
), "")[1],
|
539 |
inputs=[output_for_flagging],
|
540 |
outputs=[flag_output],
|
|
|
543 |
lambda d: (flag_callback.flag(
|
544 |
get_flag_callback_args(d, "π"),
|
545 |
flag_option="Down Vote",
|
546 |
+
username=None,
|
547 |
+
filename=get_flag_filename(d)
|
548 |
), "")[1],
|
549 |
inputs=[output_for_flagging],
|
550 |
outputs=[flag_output],
|