LukasHug commited on
Commit
46282cc
1 Parent(s): 76bcd78

upload demo

Browse files
app.py CHANGED
@@ -1,146 +1,122 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
- from diffusers import DiffusionPipeline
5
- import torch
6
-
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
-
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
12
- pipe.enable_xformers_memory_efficient_attention()
13
- pipe = pipe.to(device)
14
- else:
15
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
16
- pipe = pipe.to(device)
17
-
18
- MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 1024
20
-
21
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
22
-
23
- if randomize_seed:
24
- seed = random.randint(0, MAX_SEED)
25
-
26
- generator = torch.Generator().manual_seed(seed)
27
-
28
- image = pipe(
29
- prompt = prompt,
30
- negative_prompt = negative_prompt,
31
- guidance_scale = guidance_scale,
32
- num_inference_steps = num_inference_steps,
33
- width = width,
34
- height = height,
35
- generator = generator
36
- ).images[0]
37
-
38
- return image
39
-
40
- examples = [
41
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
42
- "An astronaut riding a green horse",
43
- "A delicious ceviche cheesecake slice",
44
- ]
45
-
46
- css="""
47
- #col-container {
48
- margin: 0 auto;
49
- max-width: 520px;
50
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- if torch.cuda.is_available():
54
- power_device = "GPU"
55
- else:
56
- power_device = "CPU"
57
-
58
- with gr.Blocks(css=css) as demo:
59
-
60
- with gr.Column(elem_id="col-container"):
61
- gr.Markdown(f"""
62
- # Text-to-Image Gradio Template
63
- Currently running on {power_device}.
64
- """)
65
-
66
- with gr.Row():
67
-
68
- prompt = gr.Text(
69
- label="Prompt",
70
- show_label=False,
71
- max_lines=1,
72
- placeholder="Enter your prompt",
73
- container=False,
74
- )
75
-
76
- run_button = gr.Button("Run", scale=0)
77
-
78
- result = gr.Image(label="Result", show_label=False)
79
-
80
- with gr.Accordion("Advanced Settings", open=False):
81
-
82
- negative_prompt = gr.Text(
83
- label="Negative prompt",
84
- max_lines=1,
85
- placeholder="Enter a negative prompt",
86
- visible=False,
87
- )
88
-
89
- seed = gr.Slider(
90
- label="Seed",
91
- minimum=0,
92
- maximum=MAX_SEED,
93
- step=1,
94
- value=0,
95
- )
96
-
97
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
-
99
- with gr.Row():
100
-
101
- width = gr.Slider(
102
- label="Width",
103
- minimum=256,
104
- maximum=MAX_IMAGE_SIZE,
105
- step=32,
106
- value=512,
107
- )
108
-
109
- height = gr.Slider(
110
- label="Height",
111
- minimum=256,
112
- maximum=MAX_IMAGE_SIZE,
113
- step=32,
114
- value=512,
115
- )
116
-
117
- with gr.Row():
118
-
119
- guidance_scale = gr.Slider(
120
- label="Guidance scale",
121
- minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=0.0,
125
- )
126
-
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=12,
131
- step=1,
132
- value=2,
133
- )
134
-
135
- gr.Examples(
136
- examples = examples,
137
- inputs = [prompt]
138
  )
139
 
140
- run_button.click(
141
- fn = infer,
142
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
- outputs = [result]
144
- )
 
145
 
146
- demo.queue().launch()
 
1
+ import sys
2
+ import os
3
+ import argparse
4
+ import time
5
+ import subprocess
6
+ if '/workspace' not in sys.path:
7
+ sys.path.append('/workspace')
8
+ import gradio_web_server as gws
9
+ # from llavaguard.hf_utils import set_up_env_and_token
10
+
11
+
12
+ # Execute the pip install command with additional options
13
+ # subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
14
+
15
+
16
+ def start_controller():
17
+ print("Starting the controller")
18
+ controller_command = [
19
+ sys.executable,
20
+ "-m",
21
+ "llava.serve.controller",
22
+ "--host",
23
+ "0.0.0.0",
24
+ "--port",
25
+ "10000",
26
+ ]
27
+ print(controller_command)
28
+ return subprocess.Popen(controller_command)
29
+
30
+
31
+ def start_worker(model_path: str, model_name: str, bits=16, device=0):
32
+ print(f"Starting the model worker for the model {model_path}")
33
+ # model_name = model_path.strip("/").split("/")[-1]
34
+ device = f"cuda:{device}" if isinstance(device, int) else device
35
+ assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
36
+ if bits != 16:
37
+ model_name += f"-{bits}bit"
38
+ worker_command = [
39
+ sys.executable,
40
+ "-m",
41
+ "llava.serve.model_worker",
42
+ "--host",
43
+ "0.0.0.0",
44
+ "--controller",
45
+ "http://localhost:10000",
46
+ "--model-path",
47
+ model_path,
48
+ "--model-name",
49
+ model_name,
50
+ "--use-flash-attn",
51
+ '--device',
52
+ device
53
+ ]
54
+ if bits != 16:
55
+ worker_command += [f"--load-{bits}bit"]
56
+ print(worker_command)
57
+ return subprocess.Popen(worker_command)
58
+
59
+
60
+ if __name__ == "__main__":
61
+ parser = argparse.ArgumentParser()
62
+ parser.add_argument("--host", type=str, default="0.0.0.0")
63
+ parser.add_argument("--port", type=int)
64
+ parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
65
+ parser.add_argument("--concurrency-count", type=int, default=5)
66
+ parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
67
+ parser.add_argument("--share", action="store_true")
68
+ parser.add_argument("--moderate", action="store_true")
69
+ parser.add_argument("--embed", action="store_true")
70
+ gws.args = parser.parse_args()
71
+ gws.models = []
72
+
73
+ gws.title_markdown += """
74
+
75
+ ONLY WORKS WITH GPU!
76
+
77
+ Set the environment variable `model` to change the model:
78
+ ['AIML-TUDA/LlavaGuard-7B'](https://huggingface.co/AIML-TUDA/LlavaGuard-7B),
79
+ ['AIML-TUDA/LlavaGuard-13B'](https://huggingface.co/AIML-TUDA/LlavaGuard-13B),
80
+ ['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
81
  """
82
+ # set_up_env_and_token(read=True)
83
+ print(f"args: {gws.args}")
84
+ # set the huggingface login token
85
+ controller_proc = start_controller()
86
+ concurrency_count = int(os.getenv("concurrency_count", 5))
87
+
88
+ models = [
89
+ 'LukasHug/LlavaGuard-7B-hf',
90
+ 'LukasHug/LlavaGuard-13B-hf',
91
+ 'LukasHug/LlavaGuard-34B-hf',]
92
+ bits = int(os.getenv("bits", 16))
93
+ model = os.getenv("model", models[0])
94
+ available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
95
+ model_path, model_name = model, model.split("/")[-1]
96
+
97
+ worker_proc = start_worker(model_path, model_name, bits=bits)
98
+
99
+
100
+ # Wait for worker and controller to start
101
+ time.sleep(10)
102
 
103
+ exit_status = 0
104
+ try:
105
+ demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
106
+ demo.queue(
107
+ status_update_rate=10,
108
+ api_open=False
109
+ ).launch(
110
+ server_name=gws.args.host,
111
+ server_port=gws.args.port,
112
+ share=gws.args.share
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  )
114
 
115
+ except Exception as e:
116
+ print(e)
117
+ exit_status = 1
118
+ finally:
119
+ worker_proc.kill()
120
+ controller_proc.kill()
121
 
122
+ sys.exit(exit_status)
examples/image1.png ADDED
examples/image2.png ADDED
examples/image3.png ADDED
examples/image4.png ADDED
examples/image5.png ADDED
examples/image6.png ADDED
gradio_web_server.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from llava.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from llava.constants import LOGDIR
13
+ from llava.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ import hashlib
16
+
17
+ from taxonomy import wrap_taxonomy, default_taxonomy
18
+
19
+
20
+ def clear_conv(conv):
21
+ conv.messages = []
22
+ return conv
23
+
24
+
25
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
26
+
27
+ headers = {"User-Agent": "LLaVA Client"}
28
+
29
+ no_change_btn = gr.Button()
30
+ enable_btn = gr.Button(interactive=True)
31
+ disable_btn = gr.Button(interactive=False)
32
+
33
+ priority = {
34
+ "LlavaGuard-7B": "aaaaaaa",
35
+ "LlavaGuard-13B": "aaaaaab",
36
+ "LlavaGuard-34B": "aaaaaac",
37
+ }
38
+
39
+
40
+ def get_conv_log_filename():
41
+ t = datetime.datetime.now()
42
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
43
+ return name
44
+
45
+
46
+ def get_model_list():
47
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
48
+ assert ret.status_code == 200
49
+ ret = requests.post(args.controller_url + "/list_models")
50
+ logger.info(f"get_model_list: {ret.json()}")
51
+ models = ret.json()["models"]
52
+ models.sort(key=lambda x: priority.get(x, x))
53
+ logger.info(f"Models: {models}")
54
+ return models
55
+
56
+
57
+ get_window_url_params = """
58
+ function() {
59
+ const params = new URLSearchParams(window.location.search);
60
+ url_params = Object.fromEntries(params);
61
+ console.log(url_params);
62
+ return url_params;
63
+ }
64
+ """
65
+
66
+
67
+ def load_demo(url_params, request: gr.Request):
68
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
69
+
70
+ dropdown_update = gr.Dropdown(visible=True)
71
+ if "model" in url_params:
72
+ model = url_params["model"]
73
+ if model in models:
74
+ dropdown_update = gr.Dropdown(value=model, visible=True)
75
+
76
+ state = default_conversation.copy()
77
+ return state, dropdown_update
78
+
79
+
80
+ def load_demo_refresh_model_list(request: gr.Request):
81
+ logger.info(f"load_demo. ip: {request.client.host}")
82
+ models = get_model_list()
83
+ state = default_conversation.copy()
84
+ dropdown_update = gr.Dropdown(
85
+ choices=models,
86
+ value=models[0] if len(models) > 0 else ""
87
+ )
88
+ return state, dropdown_update
89
+
90
+
91
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
92
+ with open(get_conv_log_filename(), "a") as fout:
93
+ data = {
94
+ "tstamp": round(time.time(), 4),
95
+ "type": vote_type,
96
+ "model": model_selector,
97
+ "state": state.dict(),
98
+ "ip": request.client.host,
99
+ }
100
+ fout.write(json.dumps(data) + "\n")
101
+
102
+
103
+ def upvote_last_response(state, model_selector, request: gr.Request):
104
+ logger.info(f"upvote. ip: {request.client.host}")
105
+ vote_last_response(state, "upvote", model_selector, request)
106
+ return ("",) + (disable_btn,) * 3
107
+
108
+
109
+ def downvote_last_response(state, model_selector, request: gr.Request):
110
+ logger.info(f"downvote. ip: {request.client.host}")
111
+ vote_last_response(state, "downvote", model_selector, request)
112
+ return ("",) + (disable_btn,) * 3
113
+
114
+
115
+ def flag_last_response(state, model_selector, request: gr.Request):
116
+ logger.info(f"flag. ip: {request.client.host}")
117
+ vote_last_response(state, "flag", model_selector, request)
118
+ return ("",) + (disable_btn,) * 3
119
+
120
+
121
+ def regenerate(state, image_process_mode, request: gr.Request):
122
+ logger.info(f"regenerate. ip: {request.client.host}")
123
+ state.messages[-1][-1] = None
124
+ prev_human_msg = state.messages[-2]
125
+ if type(prev_human_msg[1]) in (tuple, list):
126
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
127
+ state.skip_next = False
128
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
129
+
130
+
131
+ def clear_history(request: gr.Request):
132
+ logger.info(f"clear_history. ip: {request.client.host}")
133
+ state = default_conversation.copy()
134
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
135
+
136
+
137
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
138
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
139
+ if len(text) <= 0 or image is None:
140
+ state.skip_next = True
141
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
142
+ if args.moderate:
143
+ flagged = violates_moderation(text)
144
+ if flagged:
145
+ state.skip_next = True
146
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
147
+ no_change_btn,) * 5
148
+
149
+ text = wrap_taxonomy(text)
150
+ if image is not None:
151
+ text = text # Hard cut-off for images
152
+ if '<image>' not in text:
153
+ # text = '<Image><image></Image>' + text
154
+ text = text + '\n<image>'
155
+ text = (text, image, image_process_mode)
156
+ state = default_conversation.copy()
157
+ state = clear_conv(state)
158
+ state.append_message(state.roles[0], text)
159
+ state.append_message(state.roles[1], None)
160
+ state.skip_next = False
161
+ return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
162
+
163
+
164
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
165
+ logger.info(f"http_bot. ip: {request.client.host}")
166
+ start_tstamp = time.time()
167
+ model_name = model_selector
168
+
169
+ if state.skip_next:
170
+ # This generate call is skipped due to invalid inputs
171
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
172
+ return
173
+
174
+ if len(state.messages) == state.offset + 2:
175
+ # First round of conversation
176
+ if "llava" in model_name.lower():
177
+ if 'llama-2' in model_name.lower():
178
+ template_name = "llava_llama_2"
179
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
180
+ if 'orca' in model_name.lower():
181
+ template_name = "mistral_orca"
182
+ elif 'hermes' in model_name.lower():
183
+ template_name = "chatml_direct"
184
+ else:
185
+ template_name = "mistral_instruct"
186
+ elif 'llava-v1.6-34b' in model_name.lower():
187
+ template_name = "chatml_direct"
188
+ elif "v1" in model_name.lower():
189
+ if 'mmtag' in model_name.lower():
190
+ template_name = "v1_mmtag"
191
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
192
+ template_name = "v1_mmtag"
193
+ else:
194
+ template_name = "llava_v1"
195
+ elif "mpt" in model_name.lower():
196
+ template_name = "mpt"
197
+ else:
198
+ if 'mmtag' in model_name.lower():
199
+ template_name = "v0_mmtag"
200
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
201
+ template_name = "v0_mmtag"
202
+ else:
203
+ template_name = "llava_v0"
204
+ elif "mpt" in model_name:
205
+ template_name = "mpt_text"
206
+ elif "llama-2" in model_name:
207
+ template_name = "llama_2"
208
+ else:
209
+ template_name = "vicuna_v1"
210
+ new_state = conv_templates[template_name].copy()
211
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
212
+ new_state.append_message(new_state.roles[1], None)
213
+ state = new_state
214
+
215
+ # Query worker address
216
+ controller_url = args.controller_url
217
+ ret = requests.post(controller_url + "/get_worker_address",
218
+ json={"model": model_name})
219
+ worker_addr = ret.json()["address"]
220
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
221
+
222
+ # No available worker
223
+ if worker_addr == "":
224
+ state.messages[-1][-1] = server_error_msg
225
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
226
+ return
227
+
228
+ # Construct prompt
229
+ prompt = state.get_prompt()
230
+
231
+ all_images = state.get_images(return_pil=True)
232
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
233
+ for image, hash in zip(all_images, all_image_hash):
234
+ t = datetime.datetime.now()
235
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
236
+ if not os.path.isfile(filename):
237
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
238
+ image.save(filename)
239
+
240
+ # Make requests
241
+ pload = {
242
+ "model": model_name,
243
+ "prompt": prompt,
244
+ "temperature": float(temperature),
245
+ "top_p": float(top_p),
246
+ # "num_beams": 2,
247
+ # "top_k": 50,
248
+ "max_new_tokens": min(int(max_new_tokens), 1536),
249
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
250
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
251
+ }
252
+ logger.info(f"==== request ====\n{pload}")
253
+
254
+ pload['images'] = state.get_images()
255
+
256
+ state.messages[-1][-1] = "▌"
257
+
258
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
259
+
260
+ try:
261
+ # Stream output
262
+ response = requests.post(worker_addr + "/worker_generate_stream",
263
+ headers=headers, json=pload, stream=True, timeout=10)
264
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
265
+ if chunk:
266
+ data = json.loads(chunk.decode())
267
+ if data["error_code"] == 0:
268
+ output = data["text"][len(prompt):].strip()
269
+ state.messages[-1][-1] = output + "▌"
270
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
271
+ else:
272
+ output = data["text"] + f" (error_code: {data['error_code']})"
273
+ state.messages[-1][-1] = output
274
+ yield (state, state.to_gradio_chatbot()) + (
275
+ disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
276
+ return
277
+ time.sleep(0.03)
278
+ except requests.exceptions.RequestException as e:
279
+ state.messages[-1][-1] = server_error_msg
280
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
281
+ return
282
+
283
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
284
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
285
+
286
+ finish_tstamp = time.time()
287
+ logger.info(f"{output}")
288
+
289
+ with open(get_conv_log_filename(), "a") as fout:
290
+ data = {
291
+ "tstamp": round(finish_tstamp, 4),
292
+ "type": "chat",
293
+ "model": model_name,
294
+ "start": round(start_tstamp, 4),
295
+ "finish": round(finish_tstamp, 4),
296
+ "state": state.dict(),
297
+ "images": all_image_hash,
298
+ "ip": request.client.host,
299
+ }
300
+ fout.write(json.dumps(data) + "\n")
301
+
302
+
303
+ title_markdown = ("""
304
+ # LLAVAGUARD: VLM-based Safeguard for Vision Dataset Curation and Safety Assessment
305
+ [[Project Page](https://ml-research.github.io/human-centered-genai/projects/llavaguard/index.html)]
306
+ [[Code](https://github.com/ml-research/LlavaGuard)]
307
+ [[Model](https://huggingface.co/collections/AIML-TUDA/llavaguard-665b42e89803408ee8ec1086)]
308
+ [[Dataset](https://huggingface.co/datasets/aiml-tuda/llavaguard)]
309
+ [[LavaGuard](https://arxiv.org/abs/2406.05113)]
310
+ """)
311
+
312
+ tos_markdown = ("""
313
+ ### Terms of use
314
+ By using this service, users are required to agree to the following terms:
315
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
316
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
317
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
318
+ """)
319
+
320
+ learn_more_markdown = ("""
321
+ ### License
322
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
323
+ """)
324
+
325
+ block_css = """
326
+
327
+ #buttons button {
328
+ min-width: min(120px,100%);
329
+ }
330
+
331
+ """
332
+
333
+ taxonomies = ["Default", "Modified w/ O1 non-violating", "Default message 3"]
334
+
335
+
336
+ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
337
+ textbox = gr.Textbox(
338
+ label="Safety Risk Taxonomy",
339
+ show_label=True,
340
+ placeholder="Enter your safety policy here",
341
+ container=True,
342
+ value=default_taxonomy,
343
+ lines=50)
344
+ with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
345
+ state = gr.State()
346
+
347
+ if not embed_mode:
348
+ gr.Markdown(title_markdown)
349
+
350
+ with gr.Row():
351
+ with gr.Column(scale=3):
352
+ with gr.Row(elem_id="model_selector_row"):
353
+ model_selector = gr.Dropdown(
354
+ choices=models,
355
+ value=models[0] if len(models) > 0 else "",
356
+ interactive=True,
357
+ show_label=False,
358
+ container=False)
359
+
360
+ imagebox = gr.Image(type="pil", label="Image", container=False)
361
+ image_process_mode = gr.Radio(
362
+ ["Crop", "Resize", "Pad", "Default"],
363
+ value="Default",
364
+ label="Preprocess for non-square image", visible=False)
365
+
366
+ if cur_dir is None:
367
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
368
+ gr.Examples(examples=[
369
+ [f"{cur_dir}/examples/image{i}.png"] for i in range(1,6)
370
+ ], inputs=imagebox)
371
+
372
+ with gr.Accordion("Parameters", open=False) as parameter_row:
373
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
374
+ label="Temperature", )
375
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.1, interactive=True, label="Top P", )
376
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
377
+ label="Max output tokens", )
378
+
379
+ with gr.Column(scale=8):
380
+ chatbot = gr.Chatbot(
381
+ elem_id="chatbot",
382
+ label="LLavaGuard Safety Assessment",
383
+ height=650,
384
+ layout="panel",
385
+ )
386
+ with gr.Row():
387
+ with gr.Column(scale=8):
388
+ textbox.render()
389
+ with gr.Column(scale=1, min_width=50):
390
+ submit_btn = gr.Button(value="Send", variant="primary")
391
+ with gr.Row(elem_id="buttons") as button_row:
392
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
393
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
394
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
395
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
396
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
397
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
398
+
399
+ if not embed_mode:
400
+ gr.Markdown(tos_markdown)
401
+ gr.Markdown(learn_more_markdown)
402
+ url_params = gr.JSON(visible=False)
403
+
404
+ # Register listeners
405
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
406
+ upvote_btn.click(
407
+ upvote_last_response,
408
+ [state, model_selector],
409
+ [textbox, upvote_btn, downvote_btn, flag_btn]
410
+ )
411
+ downvote_btn.click(
412
+ downvote_last_response,
413
+ [state, model_selector],
414
+ [textbox, upvote_btn, downvote_btn, flag_btn]
415
+ )
416
+ flag_btn.click(
417
+ flag_last_response,
418
+ [state, model_selector],
419
+ [textbox, upvote_btn, downvote_btn, flag_btn]
420
+ )
421
+
422
+ regenerate_btn.click(
423
+ regenerate,
424
+ [state, image_process_mode],
425
+ [state, chatbot, textbox, imagebox] + btn_list
426
+ ).then(
427
+ http_bot,
428
+ [state, model_selector, temperature, top_p, max_output_tokens],
429
+ [state, chatbot] + btn_list,
430
+ concurrency_limit=concurrency_count
431
+ )
432
+
433
+ clear_btn.click(
434
+ clear_history,
435
+ None,
436
+ [state, chatbot, textbox, imagebox] + btn_list,
437
+ queue=False
438
+ )
439
+
440
+ textbox.submit(
441
+ add_text,
442
+ [state, textbox, imagebox, image_process_mode],
443
+ [state, chatbot, textbox, imagebox] + btn_list,
444
+ queue=False
445
+ ).then(
446
+ http_bot,
447
+ [state, model_selector, temperature, top_p, max_output_tokens],
448
+ [state, chatbot] + btn_list,
449
+ concurrency_limit=concurrency_count
450
+ )
451
+
452
+ submit_btn.click(
453
+ add_text,
454
+ [state, textbox, imagebox, image_process_mode],
455
+ [state, chatbot, textbox, imagebox] + btn_list
456
+ ).then(
457
+ http_bot,
458
+ [state, model_selector, temperature, top_p, max_output_tokens],
459
+ [state, chatbot] + btn_list,
460
+ concurrency_limit=concurrency_count
461
+ )
462
+
463
+ if args.model_list_mode == "once":
464
+ demo.load(
465
+ load_demo,
466
+ [url_params],
467
+ [state, model_selector],
468
+ js=get_window_url_params
469
+ )
470
+ elif args.model_list_mode == "reload":
471
+ demo.load(
472
+ load_demo_refresh_model_list,
473
+ None,
474
+ [state, model_selector],
475
+ queue=False
476
+ )
477
+ else:
478
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
479
+
480
+ return demo
481
+
482
+
483
+ if __name__ == "__main__":
484
+ parser = argparse.ArgumentParser()
485
+ parser.add_argument("--host", type=str, default="0.0.0.0")
486
+ parser.add_argument("--port", type=int)
487
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
488
+ parser.add_argument("--concurrency-count", type=int, default=16)
489
+ parser.add_argument("--model-list-mode", type=str, default="once",
490
+ choices=["once", "reload"])
491
+ parser.add_argument("--share", action="store_true")
492
+ parser.add_argument("--moderate", action="store_true")
493
+ parser.add_argument("--embed", action="store_true")
494
+ args = parser.parse_args()
495
+ logger.info(f"args: {args}")
496
+
497
+ models = get_model_list()
498
+
499
+ logger.info(args)
500
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
501
+ demo.queue(
502
+ api_open=False
503
+ ).launch(
504
+ server_name=args.host,
505
+ server_port=args.port,
506
+ share=args.share
507
+ )
taxonomy.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_taxonomy = """O1: Hate, Humiliation, Harassment:
2
+ Should not:
3
+ - Contain content that expresses hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
4
+ - Contain content that encourages discrimination or harassment against people based on those characteristics.
5
+ Can:
6
+ - Contain educational content that provides information on discrimination, humiliation, and harassment.
7
+ - Discuss historical events involving hate and harassment.
8
+ O2: Violence, Harm, or Cruelty:
9
+ Should not:
10
+ - Contain content of people planning or engaging in violence.
11
+ - Contain content that encourages people to engage in violence or cruelty, including harm to others.
12
+ Can:
13
+ - Contain educational content that provides information on violence, harm, and cruelty.
14
+ - Discuss historical events involving violence and cruelty.
15
+ O3: Sexual Content:
16
+ Should not:
17
+ - Contain sexually explicit (i.e., erotic) content.
18
+ Can:
19
+ - Contain content that is related to sexuality and sexual education.
20
+ - Provide general information regarding sexual activities.
21
+ - Provide guidance on topics related to sexual health.
22
+ - Contain content of people in a loving and affectionate manner if it is not sexually explicit content.
23
+ O4: Nudity Content:
24
+ Should not:
25
+ - Contain visible genitalia (e.g., exposed breast, nipple, vagina, vulva, or penis) or explicit nudity content.
26
+ Can:
27
+ - Contain artistic nudity if the content does not include visible genitalia.
28
+ - Contain educational content and general information on the human body.
29
+ - Provide guidance on topics related to sexual health.
30
+ O5: Criminal Planning:
31
+ Should not:
32
+ - Contain content of people planning or engaging in criminal activities like theft, kidnapping, or financial crimes.
33
+ - Contain content that encourages people to engage in criminal activities or condone criminal behavior.
34
+ Can:
35
+ - Contain content that helps people understand how crimes are committed, provided that the content is not so detailed that it could be used to provide a roadmap for future crimes.
36
+ - Contain content that educates on crime prevention and understanding criminal behavior without providing guidance for criminal activities.'
37
+ O6: Weapons or Substance Abuse:
38
+ Should not:
39
+ - Contain content of illegal weapons with the intent to promote violence or harm.
40
+ - Contain content of trafficking, consumption, or use of illegal drugs.
41
+ - Promote the use of tobacco, alcohol, and cannabis.
42
+ - Contain content about the creation or production of illegal substances.
43
+ Can:
44
+ - Contain weapons used in a legal and safe context without any indication of violence or intent to harm.
45
+ - Discussions on firearm ownership and educational content on regulated substances without promoting illegal activities.
46
+ - Contain content that provides factual information about legal drugs, including their effects and chemical composition, without promoting their use.
47
+ - Contain content that offers educational insights into the manufacturing processes of substances, ensuring that such information is general and does not include detailed instructions that could facilitate their production.
48
+ O7: Self-Harm:
49
+ Should not:
50
+ - Contain content that could encourage people to harm themselves.
51
+ - Contain content that romanticizes or condones self-harm or suicide.
52
+ Can:
53
+ - Contain educational content that informs about the prevalence, causes, and effects of self-harm and suicide, such as statistical data and discussions on mental health challenges.
54
+ - Contain resources and information that guide individuals toward professional help and support services for mental health, self-harm, and suicide prevention.
55
+ O8: Animal Cruelty:
56
+ Should not:
57
+ - Contain content that could encourage people to harm animals.
58
+ - Contain content that romanticizes or condones harming animals.
59
+ - Contain explicit animal cruelty.
60
+ Can:
61
+ - Contain content that provides general information regarding animal cruelty.
62
+ O9: Disasters or Emergencies:
63
+ Should not:
64
+ - Contain content about individuals who are deceased or have suffered harm as a result of disasters or emergencies.
65
+ - Contain content about suffering, cruelty, or exploitation of individuals during such events.
66
+ Can:
67
+ - Contain respectful news coverage and reports on emergencies, prioritizing informational value over sensationalism.
68
+ - Contain educational content that provides factual information on historical and current disasters or emergencies."""
69
+
70
+
71
+ def wrap_taxonomy(taxonomy: str) -> str:
72
+ return f"""Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
73
+ {taxonomy}""" + """
74
+ Assessment Steps:
75
+ 1. Select a safety "rating". This should be "Unsafe" if the content violates the safety policy; otherwise, it should be "Safe".
76
+ 2. Assign a policy "category". If none of the policy categories are applicable, select "NA: None applying".
77
+ 3. Provide a "rationale". Explain the reason why the content is considered safe or unsafe, referring to the description of the assigned policy category.
78
+ To provide your assessment use the following json template:
79
+ {
80
+ "rating": str<"Unsafe"|"Safe">,
81
+ "category": str<"O1: Hate, Humiliation, Harassment"|"O2: Violence, Harm, or Cruelty"|"O3: Sexual Content"|"O4: Nudity Content"|"O5: Criminal Planning"|"O6: Weapons or Substance Abuse"|"O7: Self-Harm"|"O8: Animal Cruelty"|"O9: Disasters or Emergencies"|"NA: None applying">,
82
+ "rationale": str,
83
+ }
84
+ <image>"""