zetavg commited on
Commit
db1ee85
Β·
1 Parent(s): 469dc51

Fix "operation not supported" error while flagging, and use different log files for different models

Browse files
llama_lora/lib/csv_logger.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio import FlaggingCallback, utils
2
+ import csv
3
+ import datetime
4
+ import os
5
+ import re
6
+ import secrets
7
+ from pathlib import Path
8
+ from typing import Any, List
9
+
10
class CSVLogger(FlaggingCallback):
    """
    A FlaggingCallback that logs each flagged sample (both the input and
    output data) to a CSV file with headers on the machine running the
    gradio app.

    Compared with gradio's built-in CSVLogger, ``flag`` accepts a
    ``filename`` argument so different models can log to different files,
    and it falls back to a rewrite-based strategy on filesystems where
    opening a file in append mode raises
    "OSError: [Errno 95] Operation not supported"
    (seen on some cloud-mounted directories).

    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=CSVLogger())
    Guides: using_flagging
    """

    def __init__(self):
        pass

    def setup(
        self,
        components: List[Any],
        flagging_dir: str | Path,
    ):
        """Remember the components to log and ensure the log dir exists."""
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str = "",
        username: str | None = None,
        filename: str = "log.csv",
    ) -> int:
        """Append one flagged sample to ``filename`` inside the flagging dir.

        Args:
            flag_data: one value per component registered in ``setup``.
            flag_option: the flag label (e.g. "Flag", "Up Vote").
            username: logged as-is, or "" when None.
            filename: target CSV name; unsafe path characters are replaced.

        Returns:
            The number of data rows (header excluded) now in the file.
        """
        flagging_dir = self.flagging_dir
        # Replace characters that are unsafe in file names so the caller
        # may pass arbitrary model names as part of the filename.
        filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename)
        log_filepath = Path(flagging_dir) / filename
        is_new = not Path(log_filepath).exists()
        headers = [
            getattr(component, "label", None) or f"component {idx}"
            for idx, component in enumerate(self.components)
        ] + [
            "flag",
            "username",
            "timestamp",
        ]

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            # Components that save artifacts (e.g. images) get their own
            # subdirectory named after the component's label.
            save_dir = Path(
                flagging_dir
            ) / (
                getattr(component, "label", None) or f"component {idx}"
            )
            if utils.is_update(sample):
                csv_data.append(str(sample))
            else:
                csv_data.append(
                    component.deserialize(sample, save_dir=save_dir)
                    if sample is not None
                    else ""
                )
        csv_data.append(flag_option)
        csv_data.append(username if username is not None else "")
        csv_data.append(str(datetime.datetime.now()))

        try:
            with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                writer = csv.writer(csvfile)
                if is_new:
                    writer.writerow(utils.sanitize_list_for_csv(headers))
                writer.writerow(utils.sanitize_list_for_csv(csv_data))
        except OSError:
            # Workaround for "OSError: [Errno 95] Operation not supported"
            # raised by append-mode opens on some cloud-mounted directories:
            # write the new row(s) to a temporary file opened in "w" mode
            # (the previous code opened the tmp file in "a" mode too, which
            # could fail for the same reason), then rebuild the log from
            # scratch. Pure-Python file operations replace the previous
            # os.system() calls, which interpolated file paths into a shell
            # command and were injectable via filenames containing single
            # quotes (the sanitizing regex above strips '"' but not "'").
            random_hex = secrets.token_hex(16)
            tmp_log_filepath = f"{log_filepath}.tmp_{random_hex}"
            with open(tmp_log_filepath, "w", newline="", encoding="utf-8") as csvfile:
                writer = csv.writer(csvfile)
                if is_new:
                    writer.writerow(utils.sanitize_list_for_csv(headers))
                writer.writerow(utils.sanitize_list_for_csv(csv_data))
            old_content = ""
            if not is_new:
                with open(log_filepath, "r", encoding="utf-8") as old_file:
                    old_content = old_file.read()
            with open(tmp_log_filepath, "r", encoding="utf-8") as tmp_file:
                new_content = tmp_file.read()
            # Rewriting the whole file ("w") works where appending ("a")
            # does not, matching the old `cat old tmp > log` behavior.
            with open(log_filepath, "w", newline="", encoding="utf-8") as csvfile:
                csvfile.write(old_content)
                csvfile.write(new_content)
            os.remove(tmp_log_filepath)

        # Count data rows (header excluded) to report back to the caller.
        with open(log_filepath, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1
        return line_count
llama_lora/ui/inference_ui.py CHANGED
@@ -10,6 +10,7 @@ from transformers import GenerationConfig
10
  from ..globals import Global
11
  from ..models import get_model, get_tokenizer, get_device
12
  from ..lib.inference import generate
 
13
  from ..utils.data import (
14
  get_available_template_names,
15
  get_available_lora_model_names,
@@ -320,7 +321,7 @@ def inference_ui():
320
  if not os.path.exists(flagging_dir):
321
  os.makedirs(flagging_dir)
322
 
323
- flag_callback = gr.CSVLogger()
324
  flag_components = [
325
  LoggingItem("Base Model"),
326
  LoggingItem("Adaptor Model"),
@@ -366,6 +367,18 @@ def inference_ui():
366
  json.dumps(output_for_flagging.get("generation_config", "")),
367
  ]
368
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  things_that_might_timeout = []
370
 
371
  with gr.Blocks() as inference_ui_blocks:
@@ -510,7 +523,8 @@ def inference_ui():
510
  lambda d: (flag_callback.flag(
511
  get_flag_callback_args(d, "Flag"),
512
  flag_option="Flag",
513
- username=None
 
514
  ), "")[1],
515
  inputs=[output_for_flagging],
516
  outputs=[flag_output],
@@ -519,7 +533,8 @@ def inference_ui():
519
  lambda d: (flag_callback.flag(
520
  get_flag_callback_args(d, "πŸ‘"),
521
  flag_option="Up Vote",
522
- username=None
 
523
  ), "")[1],
524
  inputs=[output_for_flagging],
525
  outputs=[flag_output],
@@ -528,7 +543,8 @@ def inference_ui():
528
  lambda d: (flag_callback.flag(
529
  get_flag_callback_args(d, "πŸ‘Ž"),
530
  flag_option="Down Vote",
531
- username=None
 
532
  ), "")[1],
533
  inputs=[output_for_flagging],
534
  outputs=[flag_output],
 
10
  from ..globals import Global
11
  from ..models import get_model, get_tokenizer, get_device
12
  from ..lib.inference import generate
13
+ from ..lib.csv_logger import CSVLogger
14
  from ..utils.data import (
15
  get_available_template_names,
16
  get_available_lora_model_names,
 
321
  if not os.path.exists(flagging_dir):
322
  os.makedirs(flagging_dir)
323
 
324
+ flag_callback = CSVLogger()
325
  flag_components = [
326
  LoggingItem("Base Model"),
327
  LoggingItem("Adaptor Model"),
 
367
  json.dumps(output_for_flagging.get("generation_config", "")),
368
  ]
369
 
370
def get_flag_filename(output_for_flagging_str):
    # Pick a log file per (base model, adaptor model) pair so flags from
    # different models land in different CSV files.
    parsed = json.loads(output_for_flagging_str)
    base_model = parsed.get("base_model", None)
    adaptor_model = parsed.get("adaptor_model", None)
    # The UI serializes "no adaptor selected" as the literal string "None".
    if adaptor_model == "None":
        adaptor_model = None
    if not base_model:
        return "log.csv"
    suffix = f"#{adaptor_model}" if adaptor_model else ""
    return f"log-{base_model}{suffix}.csv"
381
+
382
  things_that_might_timeout = []
383
 
384
  with gr.Blocks() as inference_ui_blocks:
 
523
  lambda d: (flag_callback.flag(
524
  get_flag_callback_args(d, "Flag"),
525
  flag_option="Flag",
526
+ username=None,
527
+ filename=get_flag_filename(d)
528
  ), "")[1],
529
  inputs=[output_for_flagging],
530
  outputs=[flag_output],
 
533
  lambda d: (flag_callback.flag(
534
  get_flag_callback_args(d, "πŸ‘"),
535
  flag_option="Up Vote",
536
+ username=None,
537
+ filename=get_flag_filename(d)
538
  ), "")[1],
539
  inputs=[output_for_flagging],
540
  outputs=[flag_output],
 
543
  lambda d: (flag_callback.flag(
544
  get_flag_callback_args(d, "πŸ‘Ž"),
545
  flag_option="Down Vote",
546
+ username=None,
547
+ filename=get_flag_filename(d)
548
  ), "")[1],
549
  inputs=[output_for_flagging],
550
  outputs=[flag_output],