jadechoghari committed
Commit 714db0a · 1 Parent(s): 1342b13
__pycache__/gradio_css.cpython-310.pyc CHANGED
Binary files a/__pycache__/gradio_css.cpython-310.pyc and b/__pycache__/gradio_css.cpython-310.pyc differ
 
__pycache__/inference.cpython-310.pyc CHANGED
Binary files a/__pycache__/inference.cpython-310.pyc and b/__pycache__/inference.cpython-310.pyc differ
 
__pycache__/model_UI.cpython-310.pyc ADDED
Binary file (8.79 kB).
 
app.py CHANGED
@@ -1,8 +1,3 @@
1
- '''
2
- Usage:
3
-
4
- python -m ferret.serve.gradio_web_server --controller http://localhost:10000 --add_region_feature
5
- '''
6
  import argparse
7
  import datetime
8
  import json
@@ -11,109 +6,29 @@ import time
11
 
12
  import gradio as gr
13
  import requests
14
-
15
  from conversation import (default_conversation, conv_templates,
16
  SeparatorStyle)
17
- from constants import LOGDIR
 
18
  from utils import (build_logger, server_error_msg,
19
- violates_moderation, moderation_msg)
20
  import hashlib
21
- # Added
22
- import re
23
- from copy import deepcopy
24
- from PIL import ImageDraw, ImageFont
25
- from gradio import processing_utils
26
- import numpy as np
27
- import torch
28
- import torch.nn.functional as F
29
- from scipy.ndimage import binary_dilation, binary_erosion
30
- import pdb
31
- from gradio_css import code_highlight_css
32
  import spaces
33
 
34
- from inference import inference_and_run
35
-
36
- DEFAULT_REGION_REFER_TOKEN = "[region]"
37
- DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
38
-
39
-
40
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
41
 
42
- headers = {"User-Agent": "FERRET Client"}
43
 
44
- no_change_btn = gr.Button.update()
45
- enable_btn = gr.Button.update(interactive=True)
46
- disable_btn = gr.Button.update(interactive=False)
47
 
48
  priority = {
49
  "vicuna-13b": "aaaaaaa",
50
  "koala-13b": "aaaaaab",
51
  }
52
 
53
- VOCAB_IMAGE_W = 1000 # 224
54
- VOCAB_IMAGE_H = 1000 # 224
55
-
56
- def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
57
- if mask is not None:
58
- assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
59
- coor_mask = torch.zeros((raw_w, raw_h))
60
- # Assume it samples a point.
61
- if len(coor) == 2:
62
- # Define window size
63
- span = 5
64
- # Make sure the window does not exceed array bounds
65
- x_min = max(0, coor[0] - span)
66
- x_max = min(raw_w, coor[0] + span + 1)
67
- y_min = max(0, coor[1] - span)
68
- y_max = min(raw_h, coor[1] + span + 1)
69
- coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
70
- assert (coor_mask==1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
71
- elif len(coor) == 4:
72
- # Box input or Sketch input.
73
- coor_mask = torch.zeros((raw_w, raw_h))
74
- coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
75
- if mask is not None:
76
- coor_mask = coor_mask * mask
77
- # coor_mask = torch.from_numpy(coor_mask)
78
- # pdb.set_trace()
79
- assert len(coor_mask.nonzero()) != 0
80
- return coor_mask.tolist()
81
-
82
-
83
- def draw_box(coor, region_mask, region_ph, img, input_mode):
84
- colors = ["red"]
85
- draw = ImageDraw.Draw(img)
86
- font = ImageFont.truetype("./DejaVuSans.ttf", size=18)
87
- if input_mode == 'Box':
88
- draw.rectangle([coor[0], coor[1], coor[2], coor[3]], outline=colors[0], width=4)
89
- draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
90
- draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
91
- elif input_mode == 'Point':
92
- r = 8
93
- leftUpPoint = (coor[0]-r, coor[1]-r)
94
- rightDownPoint = (coor[0]+r, coor[1]+r)
95
- twoPointList = [leftUpPoint, rightDownPoint]
96
- draw.ellipse(twoPointList, outline=colors[0], width=4)
97
- draw.rectangle([coor[0], coor[1], coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[1] + int(font.size * 1.2)], outline=colors[0], fill=colors[0], width=4)
98
- draw.text([coor[0] + int(font.size * 0.2), coor[1]], region_ph, font=font, fill=(255,255,255))
99
- elif input_mode == 'Sketch':
100
- draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
101
- draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
102
- # Use morphological operations to find the boundary
103
- mask = np.array(region_mask)
104
- dilated = binary_dilation(mask, structure=np.ones((3,3)))
105
- eroded = binary_erosion(mask, structure=np.ones((3,3)))
106
- boundary = dilated ^ eroded # XOR operation to find the difference between dilated and eroded mask
107
- # Loop over the boundary and paint the corresponding pixels
108
- for i in range(boundary.shape[0]):
109
- for j in range(boundary.shape[1]):
110
- if boundary[i, j]:
111
- # This is a pixel on the boundary, paint it red
112
- draw.point((i, j), fill=colors[0])
113
- else:
114
- NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
115
- return img
116
-
117
 
118
  def get_conv_log_filename():
119
  t = datetime.datetime.now()
@@ -121,7 +36,6 @@ def get_conv_log_filename():
121
  return name
122
 
123
 
124
- # TODO: return model manually just one for now called "jadechoghari/Ferret-UI-Gemma2b"
125
  def get_model_list():
126
  # ret = requests.post(args.controller_url + "/refresh_all_workers")
127
  # assert ret.status_code == 200
@@ -134,7 +48,6 @@ def get_model_list():
134
  logger.info(f"Models: {models}")
135
  return models
136
 
137
-
138
  get_window_url_params = """
139
  function() {
140
  const params = new URLSearchParams(window.location.search);
@@ -146,38 +59,25 @@ function() {
146
 
147
 
148
  def load_demo(url_params, request: gr.Request):
149
- # logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
150
 
151
- dropdown_update = gr.Dropdown.update(visible=True)
152
  if "model" in url_params:
153
  model = url_params["model"]
154
  if model in models:
155
- dropdown_update = gr.Dropdown.update(
156
- value=model, visible=True)
157
 
158
  state = default_conversation.copy()
159
- print("state", state)
160
- return (state,
161
- dropdown_update,
162
- gr.Chatbot.update(visible=True),
163
- gr.Textbox.update(visible=True),
164
- gr.Button.update(visible=True),
165
- gr.Row.update(visible=True),
166
- gr.Accordion.update(visible=True))
167
 
168
 
169
  def load_demo_refresh_model_list(request: gr.Request):
170
- # logger.info(f"load_demo. ip: {request.client.host}")
171
  models = get_model_list()
172
  state = default_conversation.copy()
173
- return (state, gr.Dropdown.update(
174
- choices=models,
175
- value=models[0] if len(models) > 0 else ""),
176
- gr.Chatbot.update(visible=True),
177
- gr.Textbox.update(visible=True),
178
- gr.Button.update(visible=True),
179
- gr.Row.update(visible=True),
180
- gr.Accordion.update(visible=True))
181
 
182
 
183
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -213,71 +113,15 @@ def regenerate(state, image_process_mode, request: gr.Request):
213
  if type(prev_human_msg[1]) in (tuple, list):
214
  prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
215
  state.skip_next = False
216
- return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
217
 
218
 
219
  def clear_history(request: gr.Request):
220
  state = default_conversation.copy()
221
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
222
- (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
223
-
224
-
225
- def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
226
- ratio_w = image_w * 1.0 / default_wh
227
- ratio_h = image_h * 1.0 / default_wh
228
-
229
- new_box = [int(box[0] * ratio_w), int(box[1] * ratio_h), \
230
- int(box[2] * ratio_w), int(box[3] * ratio_h)]
231
- return new_box
232
-
233
-
234
- def show_location(sketch_pad, chatbot):
235
- image = sketch_pad['image']
236
- img_w, img_h = image.size
237
- new_bboxes = []
238
- old_bboxes = []
239
- # chatbot[0] is image.
240
- text = chatbot[1:]
241
- for round_i in text:
242
- human_input = round_i[0]
243
- model_output = round_i[1]
244
- # TODO: Difference: vocab representation.
245
- # pattern = r'\[x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?), x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?)\]'
246
- pattern = r'\[(\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?)\]'
247
- matches = re.findall(pattern, model_output)
248
- for match in matches:
249
- x1, y1, x2, y2 = map(int, match)
250
- new_box = resize_bbox([x1, y1, x2, y2], img_w, img_h)
251
- new_bboxes.append(new_box)
252
- old_bboxes.append([x1, y1, x2, y2])
253
-
254
- set_old_bboxes = sorted(set(map(tuple, old_bboxes)), key=list(map(tuple, old_bboxes)).index)
255
- list_old_bboxes = list(map(list, set_old_bboxes))
256
-
257
- set_bboxes = sorted(set(map(tuple, new_bboxes)), key=list(map(tuple, new_bboxes)).index)
258
- list_bboxes = list(map(list, set_bboxes))
259
-
260
- output_image = deepcopy(image)
261
- draw = ImageDraw.Draw(output_image)
262
- #TODO: change from local to online path
263
- font = ImageFont.truetype("./DejaVuSans.ttf", 28)
264
- for i in range(len(list_bboxes)):
265
- x1, y1, x2, y2 = list_old_bboxes[i]
266
- x1_new, y1_new, x2_new, y2_new = list_bboxes[i]
267
- obj_string = '[obj{}]'.format(i)
268
- for round_i in text:
269
- model_output = round_i[1]
270
- model_output = model_output.replace('[{}, {}, {}, {}]'.format(x1, y1, x2, y2), obj_string)
271
- round_i[1] = model_output
272
- draw.rectangle([(x1_new, y1_new), (x2_new, y2_new)], outline="red", width=3)
273
- draw.text((x1_new+2, y1_new+5), obj_string[1:-1], fill="red", font=font)
274
-
275
- return (output_image, [chatbot[0]] + text, disable_btn)
276
-
277
-
278
- def add_text(state, text, image_process_mode, original_image, sketch_pad, request: gr.Request):
279
- image = sketch_pad['image']
280
 
 
 
281
  if len(text) <= 0 and image is None:
282
  state.skip_next = True
283
  return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
@@ -289,68 +133,20 @@ def add_text(state, text, image_process_mode, original_image, sketch_pad, reques
289
  no_change_btn,) * 5
290
 
291
  text = text[:1536] # Hard cut-off
292
- if original_image is None:
293
- assert image is not None
294
- original_image = image.copy()
295
- print('No location, copy original image in add_text')
296
-
297
  if image is not None:
298
- if state.first_round:
299
- text = text[:1200] # Hard cut-off for images
300
- if '<image>' not in text:
301
- # text = '<Image><image></Image>' + text
302
- text = text + '\n<image>'
303
- text = (text, original_image, image_process_mode)
304
- if len(state.get_images(return_pil=True)) > 0:
305
- new_state = default_conversation.copy()
306
- new_state.first_round = False
307
- state=new_state
308
- print('First round add image finsihed.')
309
-
310
  state.append_message(state.roles[0], text)
311
  state.append_message(state.roles[1], None)
312
  state.skip_next = False
313
- return (state, state.to_gradio_chatbot(), "", original_image) + (disable_btn,) * 5
314
-
315
-
316
- def post_process_code(code):
317
- sep = "\n```"
318
- if sep in code:
319
- blocks = code.split(sep)
320
- if len(blocks) % 2 == 1:
321
- for i in range(1, len(blocks), 2):
322
- blocks[i] = blocks[i].replace("\\_", "_")
323
- code = sep.join(blocks)
324
- return code
325
-
326
-
327
- def find_indices_in_order(str_list, STR):
328
- indices = []
329
- i = 0
330
- while i < len(STR):
331
- for element in str_list:
332
- if STR[i:i+len(element)] == element:
333
- indices.append(str_list.index(element))
334
- i += len(element) - 1
335
- break
336
- i += 1
337
- return indices
338
-
339
-
340
- def format_region_prompt(prompt, refer_input_state):
341
- # Find regions in prompts and assign corresponding region masks
342
- refer_input_state['region_masks_in_prompts'] = []
343
- indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
344
- refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
345
-
346
- # Find regions in prompts and replace with real coordinates and region feature token.
347
- for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
348
- prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
349
- return prompt
350
-
351
  @spaces.GPU()
352
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_input_state, request: gr.Request):
353
- # def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
354
  start_tstamp = time.time()
355
  model_name = model_selector
356
 
@@ -359,42 +155,49 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
359
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
360
  return
361
 
362
- print("state messages: ", state.messages)
363
  if len(state.messages) == state.offset + 2:
364
  # First round of conversation
365
- # template_name = 'ferret_v1'
366
- template_name = 'ferret_gemma_instruct'
367
- # Below is LLaVA's original templates.
368
- # if "llava" in model_name.lower():
369
- # if 'llama-2' in model_name.lower():
370
- # template_name = "llava_llama_2"
371
- # elif "v1" in model_name.lower():
372
- # if 'mmtag' in model_name.lower():
373
- # template_name = "v1_mmtag"
374
- # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
375
- # template_name = "v1_mmtag"
376
- # else:
377
- # template_name = "llava_v1"
378
- # elif "mpt" in model_name.lower():
379
- # template_name = "mpt"
380
- # else:
381
- # if 'mmtag' in model_name.lower():
382
- # template_name = "v0_mmtag"
383
- # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
384
- # template_name = "v0_mmtag"
385
- # else:
386
- # template_name = "llava_v0"
387
- # elif "mpt" in model_name:
388
- # template_name = "mpt_text"
389
- # elif "llama-2" in model_name:
390
- # template_name = "llama_2"
391
- # else:
392
- # template_name = "vicuna_v1"
393
  new_state = conv_templates[template_name].copy()
394
  new_state.append_message(new_state.roles[0], state.messages[-2][1])
395
  new_state.append_message(new_state.roles[1], None)
396
  state = new_state
397
- state.first_round = False
398
 
399
  # # Query worker address
400
  # controller_url = args.controller_url
@@ -403,7 +206,9 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
403
  # worker_addr = ret.json()["address"]
404
  # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
405
 
406
- # No available worker
 
 
407
  # if worker_addr == "":
408
  # state.messages[-1][-1] = server_error_msg
409
  # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -411,15 +216,14 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
411
 
412
  # Construct prompt
413
  prompt = state.get_prompt()
414
- if args.add_region_feature:
415
- prompt = format_region_prompt(prompt, refer_input_state)
416
 
417
  all_images = state.get_images(return_pil=True)
418
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
419
  for image, hash in zip(all_images, all_image_hash):
420
  t = datetime.datetime.now()
421
- # fishy can remove it
422
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
 
423
  if not os.path.isfile(filename):
424
  os.makedirs(os.path.dirname(filename), exist_ok=True)
425
  image.save(filename)
@@ -435,23 +239,21 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
435
  "images": f'List of {len(state.get_images())} images: {all_image_hash}',
436
  }
437
  logger.info(f"==== request ====\n{pload}")
438
- if args.add_region_feature:
439
- pload['region_masks'] = refer_input_state['region_masks_in_prompts']
440
- logger.info(f"==== add region_masks_in_prompts to request ====\n")
441
 
442
  pload['images'] = state.get_images()
443
- print(f'Input Prompt: {prompt}')
444
- print("all_image_hash", all_image_hash)
445
 
446
  state.messages[-1][-1] = "▌"
447
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
448
-
449
  try:
450
  # Stream output
 
 
451
  stop = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
452
  #TODO: define inference and run function
453
  results, extracted_texts = inference_and_run(
454
- image_path=all_image_hash[0], # double check this
 
455
  prompt=prompt,
456
  model_path=model_name,
457
  conv_mode="ferret_gemma_instruct", # Default mode from the original function
@@ -460,18 +262,13 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
460
  max_new_tokens=max_new_tokens,
461
  stop=stop # Assuming we want to process the image
462
  )
463
-
464
- # response = requests.post(worker_addr + "/worker_generate_stream",
465
- # headers=headers, json=pload, stream=True, timeout=10)
466
  response = extracted_texts
467
  logger.info(f"This is the response {response}")
468
-
469
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
470
  if chunk:
471
  data = json.loads(chunk.decode())
472
  if data["error_code"] == 0:
473
  output = data["text"][len(prompt):].strip()
474
- output = post_process_code(output)
475
  state.messages[-1][-1] = output + "▌"
476
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
477
  else:
@@ -497,7 +294,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
497
  "type": "chat",
498
  "model": model_name,
499
  "start": round(start_tstamp, 4),
500
- "finish": round(start_tstamp, 4),
501
  "state": state.dict(),
502
  "images": all_image_hash,
503
  "ip": request.client.host,
@@ -505,142 +302,45 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
505
  fout.write(json.dumps(data) + "\n")
506
 
507
  title_markdown = ("""
508
- # 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
 
509
  """)
510
- # [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
511
 
512
  tos_markdown = ("""
513
  ### Terms of use
514
- By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
515
  """)
516
 
517
 
518
  learn_more_markdown = ("""
519
  ### License
520
- The service is a research preview intended for non-commercial use only
521
  """)
522
 
 
523
 
524
- css = code_highlight_css + """
525
- pre {
526
- white-space: pre-wrap; /* Since CSS 2.1 */
527
- white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
528
- white-space: -pre-wrap; /* Opera 4-6 */
529
- white-space: -o-pre-wrap; /* Opera 7 */
530
- word-wrap: break-word; /* Internet Explorer 5.5+ */
531
  }
 
532
  """
533
 
534
- Instructions = '''
535
- Instructions:
536
- 1. Select a 'Referring Input Type'
537
- 2. Draw on the image to refer to a region/point.
538
- 3. Copy the region id from 'Referring Input Type' to refer to a region in your chat.
539
- '''
540
-
541
- class ImageMask(gr.components.Image):
542
- """
543
- Sets: source="canvas", tool="sketch"
544
- """
545
-
546
- is_template = True
547
-
548
- def __init__(self, **kwargs):
549
- super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
550
-
551
- def preprocess(self, x):
552
- return super().preprocess(x)
553
-
554
-
555
- def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
556
- if type(input) == dict:
557
- image = deepcopy(input['image'])
558
- mask = deepcopy(input['mask'])
559
- else:
560
- mask = deepcopy(input)
561
-
562
- # W, H -> H, W, 3
563
- image_new = np.asarray(image)
564
- img_height = image_new.shape[0]
565
- img_width = image_new.shape[1]
566
-
567
- # W, H, 4 -> H, W
568
- mask_new = np.asarray(mask)[:,:,0].copy()
569
- mask_new = torch.from_numpy(mask_new)
570
- mask_new = (F.interpolate(mask_new.unsqueeze(0).unsqueeze(0), (img_height, img_width), mode='bilinear') > 0)
571
- mask_new = mask_new[0, 0].transpose(1, 0).long()
572
-
573
- if len(refer_input_state['masks']) == 0:
574
- last_mask = torch.zeros_like(mask_new)
575
- else:
576
- last_mask = refer_input_state['masks'][-1]
577
-
578
- diff_mask = mask_new - last_mask
579
- if torch.all(diff_mask == 0):
580
- print('Init Uploading Images.')
581
- return (refer_input_state, refer_text_show, image)
582
- else:
583
- refer_input_state['masks'].append(mask_new)
584
-
585
- if input_mode == 'Point':
586
- nonzero_points = diff_mask.nonzero()
587
- nonzero_points_avg_x = torch.median(nonzero_points[:, 0])
588
- nonzero_points_avg_y = torch.median(nonzero_points[:, 1])
589
- sampled_coor = [nonzero_points_avg_x, nonzero_points_avg_y]
590
- # pdb.set_trace()
591
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
592
- elif input_mode == 'Box' or input_mode == 'Sketch':
593
- # pdb.set_trace()
594
- x1x2 = diff_mask.max(1)[0].nonzero()[:, 0]
595
- y1y2 = diff_mask.max(0)[0].nonzero()[:, 0]
596
- y1, y2 = y1y2.min(), y1y2.max()
597
- x1, x2 = x1x2.min(), x1x2.max()
598
- # pdb.set_trace()
599
- sampled_coor = [x1, y1, x2, y2]
600
- if input_mode == 'Box':
601
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
602
- else:
603
- cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height, mask=diff_mask)
604
- else:
605
- raise NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
606
-
607
- # TODO(haoxuan): Hack img_size to be 224 here, need to make it a argument.
608
- if len(sampled_coor) == 2:
609
- point_x = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
610
- point_y = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
611
- cur_region_coordinates = f'[{int(point_x)}, {int(point_y)}]'
612
- elif len(sampled_coor) == 4:
613
- point_x1 = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
614
- point_y1 = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
615
- point_x2 = int(VOCAB_IMAGE_W * sampled_coor[2] / img_width)
616
- point_y2 = int(VOCAB_IMAGE_H * sampled_coor[3] / img_height)
617
- cur_region_coordinates = f'[{int(point_x1)}, {int(point_y1)}, {int(point_x2)}, {int(point_y2)}]'
618
-
619
- cur_region_id = len(refer_input_state['region_placeholder_tokens'])
620
- cur_region_token = DEFAULT_REGION_REFER_TOKEN.split(']')[0] + str(cur_region_id) + ']'
621
- refer_input_state['region_placeholder_tokens'].append(cur_region_token)
622
- refer_input_state['region_coordinates'].append(cur_region_coordinates)
623
- refer_input_state['region_masks'].append(cur_region_masks)
624
- assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
625
- refer_text_show.append((cur_region_token, ''))
626
-
627
- # Show Parsed Referring.
628
- imagebox_refer = draw_box(sampled_coor, cur_region_masks, \
629
- cur_region_token, imagebox_refer, input_mode)
630
-
631
- return (refer_input_state, refer_text_show, imagebox_refer)
632
-
633
- def build_demo(embed_mode):
634
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False)
635
- with gr.Blocks(title="FERRET", theme=gr.themes.Base(), css=css) as demo:
636
  state = gr.State()
637
 
638
  if not embed_mode:
639
  gr.Markdown(title_markdown)
640
- gr.Markdown(Instructions)
641
 
642
  with gr.Row():
643
- with gr.Column(scale=4):
644
  with gr.Row(elem_id="model_selector_row"):
645
  model_selector = gr.Dropdown(
646
  choices=models,
@@ -649,65 +349,43 @@ def build_demo(embed_mode):
649
  show_label=False,
650
  container=False)
651
 
652
- original_image = gr.Image(type="pil", visible=False)
653
  image_process_mode = gr.Radio(
654
- ["Raw+Processor", "Crop", "Resize", "Pad"],
655
- value="Raw+Processor",
656
- label="Preprocess for non-square image",
657
- visible=False)
658
-
659
- # Added for any-format input.
660
- sketch_pad = ImageMask(label="Image & Sketch", type="pil", elem_id="img2text")
661
- refer_input_mode = gr.Radio(
662
- ["Point", "Box", "Sketch"],
663
- value="Point",
664
- label="Referring Input Type")
665
- refer_input_state = gr.State({'region_placeholder_tokens':[],
666
- 'region_coordinates':[],
667
- 'region_masks':[],
668
- 'region_masks_in_prompts':[],
669
- 'masks':[],
670
- })
671
- refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
672
-
673
- imagebox_refer = gr.Image(type="pil", label="Parsed Referring Input")
674
- imagebox_output = gr.Image(type="pil", label='Output Vis')
675
-
676
- cur_dir = os.path.dirname(os.path.abspath(__file__))
677
- # gr.Examples(examples=[
678
- # # [f"{cur_dir}/examples/harry-potter-hogwarts.jpg", "What is in [region0]? And what do people use it for?"],
679
- # # [f"{cur_dir}/examples/ingredients.jpg", "What objects are in [region0] and [region1]?"],
680
- # # [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image? And tell me the coordinates of mentioned objects."],
681
- # [f"{cur_dir}/examples/ferret.jpg", "What's the relationship between object [region0] and object [region1]?"],
682
- # [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here? Tell me the coordinates in response."],
683
- # [f"{cur_dir}/examples/flickr_9472793441.jpg", "Describe the image in details."],
684
- # # [f"{cur_dir}/examples/coco_000000281759.jpg", "What are the locations of the woman wearing a blue dress, the woman in flowery top, the girl in purple dress, the girl wearing green shirt?"],
685
- # [f"{cur_dir}/examples/room_planning.jpg", "How to improve the design of the given room?"],
686
- # [f"{cur_dir}/examples/make_sandwitch.jpg", "How can I make a sandwich with available ingredients?"],
687
- # [f"{cur_dir}/examples/bathroom.jpg", "What is unusual about this image?"],
688
- # [f"{cur_dir}/examples/kitchen.png", "Is the object a man or a chicken? Explain the reason."],
689
- # ], inputs=[sketch_pad, textbox])
690
-
691
- with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
692
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
693
  top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
694
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
695
 
696
- with gr.Column(scale=5):
697
- chatbot = gr.Chatbot(elem_id="chatbot", label="FERRET", visible=False).style(height=750)
698
  with gr.Row():
699
  with gr.Column(scale=8):
700
  textbox.render()
701
- with gr.Column(scale=1, min_width=60):
702
- submit_btn = gr.Button(value="Submit", visible=False)
703
- with gr.Row(visible=False) as button_row:
704
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
705
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
706
- # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
707
  #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
708
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
709
- clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
710
- location_btn = gr.Button(value="🪄 Show location", interactive=False)
711
 
712
  if not embed_mode:
713
  gr.Markdown(tos_markdown)
@@ -715,69 +393,124 @@ def build_demo(embed_mode):
715
  url_params = gr.JSON(visible=False)
716
 
717
  # Register listeners
718
- btn_list = [upvote_btn, downvote_btn, location_btn, regenerate_btn, clear_btn]
719
- upvote_btn.click(upvote_last_response,
720
- [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
721
- downvote_btn.click(downvote_last_response,
722
- [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
723
- # flag_btn.click(flag_last_response,
724
- # [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
725
- regenerate_btn.click(regenerate, [state, image_process_mode],
726
- [state, chatbot, textbox] + btn_list).then(
727
- http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
728
- [state, chatbot] + btn_list)
729
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox_output, original_image] + btn_list + \
730
- [sketch_pad, refer_input_state, refer_text_show, imagebox_refer])
731
- location_btn.click(show_location,
732
- [sketch_pad, chatbot], [imagebox_output, chatbot, location_btn])
733
-
734
- textbox.submit(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
735
- ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
736
- [state, chatbot] + btn_list)
737
-
738
- submit_btn.click(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
739
- ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
740
- [state, chatbot] + btn_list)
741
-
742
- sketch_pad.edit(
743
- draw,
744
- inputs=[refer_input_mode, sketch_pad, refer_input_state, refer_text_show, imagebox_refer],
745
- outputs=[refer_input_state, refer_text_show, imagebox_refer],
746
- queue=True,
747
  )
748
 
749
  if args.model_list_mode == "once":
750
- demo.load(load_demo, [url_params], [state, model_selector,
751
- chatbot, textbox, submit_btn, button_row, parameter_row],
752
- _js=get_window_url_params)
753
  elif args.model_list_mode == "reload":
754
- demo.load(load_demo_refresh_model_list, None, [state, model_selector,
755
- chatbot, textbox, submit_btn, button_row, parameter_row])
756
  else:
757
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
758
 
759
  return demo
760
 
761
 
762
  if __name__ == "__main__":
763
  parser = argparse.ArgumentParser()
764
  parser.add_argument("--host", type=str, default="0.0.0.0")
765
  parser.add_argument("--port", type=int)
766
  parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
767
- parser.add_argument("--concurrency-count", type=int, default=8)
768
  parser.add_argument("--model-list-mode", type=str, default="once",
769
  choices=["once", "reload"])
770
  parser.add_argument("--share", action="store_true")
771
  parser.add_argument("--moderate", action="store_true")
772
  parser.add_argument("--embed", action="store_true")
773
- parser.add_argument("--add_region_feature", action="store_true")
774
  args = parser.parse_args()
775
  logger.info(f"args: {args}")
776
 
777
  models = get_model_list()
778
 
779
  logger.info(args)
780
- demo = build_demo(args.embed)
781
- demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
782
- api_open=False).launch(
783
- server_name=args.host, server_port=args.port, share=True)
1
  import argparse
2
  import datetime
3
  import json
 
6
 
7
  import gradio as gr
8
  import requests
9
+ from inference import inference_and_run
10
  from conversation import (default_conversation, conv_templates,
11
  SeparatorStyle)
12
+
13
+ LOGDIR = "."
14
  from utils import (build_logger, server_error_msg,
15
+ violates_moderation, moderation_msg)
16
  import hashlib
17
  import spaces
18
 
19
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
20
 
21
+ headers = {"User-Agent": "LLaVA Client"}
22
 
23
+ no_change_btn = gr.Button()
24
+ enable_btn = gr.Button(interactive=True)
25
+ disable_btn = gr.Button(interactive=False)
26
 
27
  priority = {
28
  "vicuna-13b": "aaaaaaa",
29
  "koala-13b": "aaaaaab",
30
  }
31
 
32
 
33
  def get_conv_log_filename():
34
  t = datetime.datetime.now()
 
36
  return name
37
 
38
 
 
39
  def get_model_list():
40
  # ret = requests.post(args.controller_url + "/refresh_all_workers")
41
  # assert ret.status_code == 200
 
48
  logger.info(f"Models: {models}")
49
  return models
50
 
 
51
  get_window_url_params = """
52
  function() {
53
  const params = new URLSearchParams(window.location.search);
 
59
 
60
 
61
  def load_demo(url_params, request: gr.Request):
 
62
 
63
+ dropdown_update = gr.Dropdown(visible=True)
64
  if "model" in url_params:
65
  model = url_params["model"]
66
  if model in models:
67
+ dropdown_update = gr.Dropdown(value=model, visible=True)
 
68
 
69
  state = default_conversation.copy()
70
+ return state, dropdown_update
71
 
72
 
73
  def load_demo_refresh_model_list(request: gr.Request):
 
74
  models = get_model_list()
75
  state = default_conversation.copy()
76
+ dropdown_update = gr.Dropdown(
77
+ choices=models,
78
+ value=models[0] if len(models) > 0 else ""
79
+ )
80
+ return state, dropdown_update
81
 
82
 
83
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
 
113
  if type(prev_human_msg[1]) in (tuple, list):
114
  prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
115
  state.skip_next = False
116
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
117
 
118
 
119
  def clear_history(request: gr.Request):
120
  state = default_conversation.copy()
121
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
122
 
123
+
124
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
125
  if len(text) <= 0 and image is None:
126
  state.skip_next = True
127
  return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
 
133
  no_change_btn,) * 5
134
 
135
  text = text[:1536] # Hard cut-off
136
  if image is not None:
137
+ text = text[:1200] # Hard cut-off for images
138
+ if '<image>' not in text:
139
+ # text = '<Image><image></Image>' + text
140
+ text = text + '\n<image>'
141
+ text = (text, image, image_process_mode)
142
+ state = default_conversation.copy()
143
  state.append_message(state.roles[0], text)
144
  state.append_message(state.roles[1], None)
145
  state.skip_next = False
146
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
147
+
148
  @spaces.GPU()
149
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
 
150
  start_tstamp = time.time()
151
  model_name = model_selector
152
 
 
155
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
156
  return
157
 
 
158
  if len(state.messages) == state.offset + 2:
159
  # First round of conversation
160
+ if "llava" in model_name.lower():
161
+ if 'llama-2' in model_name.lower():
162
+ template_name = "llava_llama_2"
163
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
164
+ if 'orca' in model_name.lower():
165
+ template_name = "mistral_orca"
166
+ elif 'hermes' in model_name.lower():
167
+ template_name = "chatml_direct"
168
+ else:
169
+ template_name = "mistral_instruct"
170
+ elif 'llava-v1.6-34b' in model_name.lower():
171
+ template_name = "chatml_direct"
172
+ elif "v1" in model_name.lower():
173
+ if 'mmtag' in model_name.lower():
174
+ template_name = "v1_mmtag"
175
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
176
+ template_name = "v1_mmtag"
177
+ else:
178
+ template_name = "llava_v1"
179
+ elif "mpt" in model_name.lower():
180
+ template_name = "mpt"
181
+ else:
182
+ if 'mmtag' in model_name.lower():
183
+ template_name = "v0_mmtag"
184
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
185
+ template_name = "v0_mmtag"
186
+ else:
187
+ template_name = "llava_v0"
188
+ elif "mpt" in model_name:
189
+ template_name = "mpt_text"
190
+ elif "llama-2" in model_name:
191
+ template_name = "llama_2"
192
+ elif "gemma" in model_name.lower():
193
+ template_name = "ferret_gemma_instruct"
194
+ print("conv mode to gemma")
195
+ else:
196
+ template_name = "vicuna_v1"
197
  new_state = conv_templates[template_name].copy()
198
  new_state.append_message(new_state.roles[0], state.messages[-2][1])
199
  new_state.append_message(new_state.roles[1], None)
200
  state = new_state
 
201
 
202
  # # Query worker address
203
  # controller_url = args.controller_url
 
206
  # worker_addr = ret.json()["address"]
207
  # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
208
 
209
+
210
+
211
+ # # No available worker
212
  # if worker_addr == "":
213
  # state.messages[-1][-1] = server_error_msg
214
  # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
216
 
217
  # Construct prompt
218
  prompt = state.get_prompt()
 
 
219
 
220
  all_images = state.get_images(return_pil=True)
221
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
222
  for image, hash in zip(all_images, all_image_hash):
223
  t = datetime.datetime.now()
224
+ dir_path = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}")
225
+ # filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
226
+ filename = os.path.join(dir_path, f"{hash}.jpg")
227
  if not os.path.isfile(filename):
228
  os.makedirs(os.path.dirname(filename), exist_ok=True)
229
  image.save(filename)
 
239
  "images": f'List of {len(state.get_images())} images: {all_image_hash}',
240
  }
241
  logger.info(f"==== request ====\n{pload}")
242
 
243
  pload['images'] = state.get_images()
 
 
244
 
245
  state.messages[-1][-1] = "▌"
246
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
247
+
248
  try:
249
  # Stream output
250
+ # response = requests.post(worker_addr + "/worker_generate_stream",
251
+ # headers=headers, json=pload, stream=True, timeout=10)
252
  stop = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
253
  #TODO: define inference and run function
254
  results, extracted_texts = inference_and_run(
255
+ image_path=filename, # double check this
256
+ image_dir=dir_path,
257
  prompt=prompt,
258
  model_path=model_name,
259
  conv_mode="ferret_gemma_instruct", # Default mode from the original function
 
262
  max_new_tokens=max_new_tokens,
263
  stop=stop # Assuming we want to process the image
264
  )
265
  response = extracted_texts
266
  logger.info(f"This is the response {response}")
 
267
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
268
  if chunk:
269
  data = json.loads(chunk.decode())
270
  if data["error_code"] == 0:
271
  output = data["text"][len(prompt):].strip()
 
272
  state.messages[-1][-1] = output + "▌"
273
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
274
  else:
 
294
  "type": "chat",
295
  "model": model_name,
296
  "start": round(start_tstamp, 4),
297
+ "finish": round(finish_tstamp, 4),
298
  "state": state.dict(),
299
  "images": all_image_hash,
300
  "ip": request.client.host,
 
302
  fout.write(json.dumps(data) + "\n")
303
 
304
  title_markdown = ("""
305
+ # 🌋 LLaVA: Large Language and Vision Assistant
306
+ [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)]
307
  """)
 
308
 
309
  tos_markdown = ("""
310
  ### Terms of use
311
+ By using this service, users are required to agree to the following terms:
312
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
313
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
314
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
315
  """)
316
 
317
 
318
  learn_more_markdown = ("""
319
  ### License
320
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
321
  """)
322
 
323
+ block_css = """
324
 
325
+ #buttons button {
326
+ min-width: min(120px,100%);
327
  }
328
+
329
  """
330
 
331
+ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
332
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
333
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
 
334
  state = gr.State()
335
 
336
  if not embed_mode:
337
  gr.Markdown(title_markdown)
 
338
 
339
  with gr.Row():
340
+ models = [
341
+ "jadechoghari/Ferret-UI-Gemma2b"
342
+ ]
343
+ with gr.Column(scale=3):
344
  with gr.Row(elem_id="model_selector_row"):
345
  model_selector = gr.Dropdown(
346
  choices=models,
 
349
  show_label=False,
350
  container=False)
351
 
352
+ imagebox = gr.Image(type="pil")
353
  image_process_mode = gr.Radio(
354
+ ["Crop", "Resize", "Pad", "Default"],
355
+ value="Default",
356
+ label="Preprocess for non-square image", visible=False)
357
+
358
+ if cur_dir is None:
359
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
360
+ gr.Examples(examples=[
361
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
362
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
363
+ ], inputs=[imagebox, textbox])
364
+
365
+ with gr.Accordion("Parameters", open=False) as parameter_row:
366
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
367
  top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
368
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
369
 
370
+ with gr.Column(scale=8):
371
+ chatbot = gr.Chatbot(
372
+ elem_id="chatbot",
373
+ label="LLaVA Chatbot",
374
+ height=650,
375
+ layout="panel",
376
+ )
377
  with gr.Row():
378
  with gr.Column(scale=8):
379
  textbox.render()
380
+ with gr.Column(scale=1, min_width=50):
381
+ submit_btn = gr.Button(value="Send", variant="primary")
382
+ with gr.Row(elem_id="buttons") as button_row:
383
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
384
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
385
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
386
  #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
387
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
388
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
389
 
390
  if not embed_mode:
391
  gr.Markdown(tos_markdown)
 
393
  url_params = gr.JSON(visible=False)
394
 
395
  # Register listeners
396
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
397
+ upvote_btn.click(
398
+ upvote_last_response,
399
+ [state, model_selector],
400
+ [textbox, upvote_btn, downvote_btn, flag_btn]
401
+ )
402
+ downvote_btn.click(
403
+ downvote_last_response,
404
+ [state, model_selector],
405
+ [textbox, upvote_btn, downvote_btn, flag_btn]
406
+ )
407
+ flag_btn.click(
408
+ flag_last_response,
409
+ [state, model_selector],
410
+ [textbox, upvote_btn, downvote_btn, flag_btn]
411
+ )
412
+
413
+ regenerate_btn.click(
414
+ regenerate,
415
+ [state, image_process_mode],
416
+ [state, chatbot, textbox, imagebox] + btn_list
417
+ ).then(
418
+ http_bot,
419
+ [state, model_selector, temperature, top_p, max_output_tokens],
420
+ [state, chatbot] + btn_list,
421
+ concurrency_limit=concurrency_count
422
+ )
423
+
424
+ clear_btn.click(
425
+ clear_history,
426
+ None,
427
+ [state, chatbot, textbox, imagebox] + btn_list,
428
+ queue=False
429
+ )
430
+
431
+ textbox.submit(
432
+ add_text,
433
+ [state, textbox, imagebox, image_process_mode],
434
+ [state, chatbot, textbox, imagebox] + btn_list,
435
+ queue=False
436
+ ).then(
437
+ http_bot,
438
+ [state, model_selector, temperature, top_p, max_output_tokens],
439
+ [state, chatbot] + btn_list,
440
+ concurrency_limit=concurrency_count
441
+ )
442
+
443
+ submit_btn.click(
444
+ add_text,
445
+ [state, textbox, imagebox, image_process_mode],
446
+ [state, chatbot, textbox, imagebox] + btn_list
447
+ ).then(
448
+ http_bot,
449
+ [state, model_selector, temperature, top_p, max_output_tokens],
450
+ [state, chatbot] + btn_list,
451
+ concurrency_limit=concurrency_count
452
  )
453
 
454
  if args.model_list_mode == "once":
455
+ demo.load(
456
+ load_demo,
457
+ [url_params],
458
+ [state, model_selector],
459
+ js=get_window_url_params
460
+ )
461
  elif args.model_list_mode == "reload":
462
+ demo.load(
463
+ load_demo_refresh_model_list,
464
+ None,
465
+ [state, model_selector],
466
+ queue=False
467
+ )
468
  else:
469
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
470
 
471
  return demo
472
 
473
 
474
+ # if __name__ == "__main__":
475
+ # parser = argparse.ArgumentParser()
476
+ # parser.add_argument("--port", type=int, default=7860) # You can still specify the port
477
+ # parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
478
+ # parser.add_argument("--concurrency-count", type=int, default=16)
479
+ # parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
480
+ # parser.add_argument("--share", action="store_true")
481
+ # parser.add_argument("--moderate", action="store_true")
482
+ # parser.add_argument("--embed", action="store_true")
483
+ # args = parser.parse_args()
484
+ # # models = get_model_list()
485
+ # demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
486
+ # demo.queue(api_open=False).launch(
487
+ # server_port=args.port, # Specify the port if needed
488
+ # share=True,
489
+ # debug=True # All other functionalities like sharing still work
490
+ # )
491
  if __name__ == "__main__":
492
  parser = argparse.ArgumentParser()
493
  parser.add_argument("--host", type=str, default="0.0.0.0")
494
  parser.add_argument("--port", type=int)
495
  parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
496
+ parser.add_argument("--concurrency-count", type=int, default=16)
497
  parser.add_argument("--model-list-mode", type=str, default="once",
498
  choices=["once", "reload"])
499
  parser.add_argument("--share", action="store_true")
500
  parser.add_argument("--moderate", action="store_true")
501
  parser.add_argument("--embed", action="store_true")
 
502
  args = parser.parse_args()
503
  logger.info(f"args: {args}")
504
 
505
  models = get_model_list()
506
 
507
  logger.info(args)
508
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
509
+ demo.queue(
510
+ api_open=False
511
+ ).launch(
512
+ server_name=args.host,
513
+ server_port=args.port,
514
+ share=True,
515
+ debug=True
516
+ )
app.pyi ADDED
@@ -0,0 +1,797 @@
1
+ '''
2
+ Usage:
3
+
4
+ python -m ferret.serve.gradio_web_server --controller http://localhost:10000 --add_region_feature
5
+ '''
6
+ import argparse
7
+ import datetime
8
+ import json
9
+ import os
10
+ import time
11
+
12
+ import gradio as gr
13
+ import requests
14
+
15
+ from conversation import (default_conversation, conv_templates,
16
+ SeparatorStyle)
17
+ from constants import LOGDIR
18
+ from utils import (build_logger, server_error_msg,
19
+ violates_moderation, moderation_msg)
20
+ import hashlib
21
+ # Added
22
+ import re
23
+ from copy import deepcopy
24
+ from PIL import ImageDraw, ImageFont
25
+ from gradio import processing_utils
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from scipy.ndimage import binary_dilation, binary_erosion
30
+ import pdb
31
+ from gradio_css import code_highlight_css
32
+ import spaces
33
+
34
+ from inference import inference_and_run
35
+
36
+ DEFAULT_REGION_REFER_TOKEN = "[region]"
37
+ DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
38
+
39
+
40
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
41
+
42
+ headers = {"User-Agent": "FERRET Client"}
43
+
44
+ no_change_btn = gr.Button
45
+ enable_btn = gr.Button(interactive=True)
46
+ disable_btn = gr.Button(interactive=False)
47
+
48
+ priority = {
49
+ "vicuna-13b": "aaaaaaa",
50
+ "koala-13b": "aaaaaab",
51
+ }
52
+
53
+ VOCAB_IMAGE_W = 1000 # 224
54
+ VOCAB_IMAGE_H = 1000 # 224
55
+
56
+ def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
57
+ if mask is not None:
58
+ assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
59
+ coor_mask = torch.zeros((raw_w, raw_h))
60
+ # Assume it samples a point.
61
+ if len(coor) == 2:
62
+ # Define window size
63
+ span = 5
64
+ # Make sure the window does not exceed array bounds
65
+ x_min = max(0, coor[0] - span)
66
+ x_max = min(raw_w, coor[0] + span + 1)
67
+ y_min = max(0, coor[1] - span)
68
+ y_max = min(raw_h, coor[1] + span + 1)
69
+ coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
70
+ assert (coor_mask==1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
71
+ elif len(coor) == 4:
72
+ # Box input or Sketch input.
73
+ coor_mask = torch.zeros((raw_w, raw_h))
74
+ coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
75
+ if mask is not None:
76
+ coor_mask = coor_mask * mask
77
+ # coor_mask = torch.from_numpy(coor_mask)
78
+ # pdb.set_trace()
79
+ assert len(coor_mask.nonzero()) != 0
80
+ return coor_mask.tolist()
81
+
82
+
83
+ def draw_box(coor, region_mask, region_ph, img, input_mode):
84
+ colors = ["red"]
85
+ draw = ImageDraw.Draw(img)
86
+ font = ImageFont.truetype("./DejaVuSans.ttf", size=18)
87
+ if input_mode == 'Box':
88
+ draw.rectangle([coor[0], coor[1], coor[2], coor[3]], outline=colors[0], width=4)
89
+ draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
90
+ draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
91
+ elif input_mode == 'Point':
92
+ r = 8
93
+ leftUpPoint = (coor[0]-r, coor[1]-r)
94
+ rightDownPoint = (coor[0]+r, coor[1]+r)
95
+ twoPointList = [leftUpPoint, rightDownPoint]
96
+ draw.ellipse(twoPointList, outline=colors[0], width=4)
97
+ draw.rectangle([coor[0], coor[1], coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[1] + int(font.size * 1.2)], outline=colors[0], fill=colors[0], width=4)
98
+ draw.text([coor[0] + int(font.size * 0.2), coor[1]], region_ph, font=font, fill=(255,255,255))
99
+ elif input_mode == 'Sketch':
100
+ draw.rectangle([coor[0], coor[3] - int(font.size * 1.2), coor[0] + int((len(region_ph) + 0.8) * font.size * 0.6), coor[3]], outline=colors[0], fill=colors[0], width=4)
101
+ draw.text([coor[0] + int(font.size * 0.2), coor[3] - int(font.size*1.2)], region_ph, font=font, fill=(255,255,255))
102
+ # Use morphological operations to find the boundary
103
+ mask = np.array(region_mask)
104
+ dilated = binary_dilation(mask, structure=np.ones((3,3)))
105
+ eroded = binary_erosion(mask, structure=np.ones((3,3)))
106
+ boundary = dilated ^ eroded # XOR operation to find the difference between dilated and eroded mask
107
+ # Loop over the boundary and paint the corresponding pixels
108
+ for i in range(boundary.shape[0]):
109
+ for j in range(boundary.shape[1]):
110
+ if boundary[i, j]:
111
+ # This is a pixel on the boundary, paint it red
112
+ draw.point((i, j), fill=colors[0])
113
+ else:
114
+ NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
115
+ return img
116
+
117
+
118
+ def get_conv_log_filename():
119
+ t = datetime.datetime.now()
120
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
121
+ return name
122
+
123
+
124
+ # TODO: return model manually just one for now called "jadechoghari/Ferret-UI-Gemma2b"
125
+ def get_model_list():
126
+ # ret = requests.post(args.controller_url + "/refresh_all_workers")
127
+ # assert ret.status_code == 200
128
+ # ret = requests.post(args.controller_url + "/list_models")
129
+ # models = ret.json()["models"]
130
+ # models.sort(key=lambda x: priority.get(x, x))
131
+ # logger.info(f"Models: {models}")
132
+ # return models
133
+ models = ["jadechoghari/Ferret-UI-Gemma2b"]
134
+ logger.info(f"Models: {models}")
135
+ return models
136
+
137
+
138
+ get_window_url_params = """
139
+ function() {
140
+ const params = new URLSearchParams(window.location.search);
141
+ url_params = Object.fromEntries(params);
142
+ console.log(url_params);
143
+ return url_params;
144
+ }
145
+ """
146
+
147
+
148
+ def load_demo(url_params, request: gr.Request):
149
+ # logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
150
+
151
+ dropdown_update = gr.Dropdown(visible=True)
152
+ if "model" in url_params:
153
+ model = url_params["model"]
154
+ if model in models:
155
+ dropdown_update = gr.Dropdown(
156
+ value=model, visible=True)
157
+
158
+ state = default_conversation.copy()
159
+ print("state", state)
160
+ return (state,
161
+ dropdown_update,
162
+ gr.Chatbot(visible=True),
163
+ gr.Textbox(visible=True),
164
+ gr.Button(visible=True),
165
+ gr.Row(visible=True),
166
+ gr.Accordion(visible=True))
167
+
168
+
169
+ def load_demo_refresh_model_list(request: gr.Request):
170
+ # logger.info(f"load_demo. ip: {request.client.host}")
171
+ models = get_model_list()
172
+ state = default_conversation.copy()
173
+ return (state, gr.Dropdown(
174
+ choices=models,
175
+ value=models[0] if len(models) > 0 else ""),
176
+ gr.Chatbot(visible=True),
177
+ gr.Textbox(visible=True),
178
+ gr.Button(visible=True),
179
+ gr.Row(visible=True),
180
+ gr.Accordion(visible=True))
181
+
182
+
183
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
184
+ with open(get_conv_log_filename(), "a") as fout:
185
+ data = {
186
+ "tstamp": round(time.time(), 4),
187
+ "type": vote_type,
188
+ "model": model_selector,
189
+ "state": state.dict(),
190
+ "ip": request.client.host,
191
+ }
192
+ fout.write(json.dumps(data) + "\n")
193
+
194
+
195
+ def upvote_last_response(state, model_selector, request: gr.Request):
196
+ vote_last_response(state, "upvote", model_selector, request)
197
+ return ("",) + (disable_btn,) * 3
198
+
199
+
200
+ def downvote_last_response(state, model_selector, request: gr.Request):
201
+ vote_last_response(state, "downvote", model_selector, request)
202
+ return ("",) + (disable_btn,) * 3
203
+
204
+
205
+ def flag_last_response(state, model_selector, request: gr.Request):
206
+ vote_last_response(state, "flag", model_selector, request)
207
+ return ("",) + (disable_btn,) * 3
208
+
209
+
210
+ def regenerate(state, image_process_mode, request: gr.Request):
211
+ state.messages[-1][-1] = None
212
+ prev_human_msg = state.messages[-2]
213
+ if type(prev_human_msg[1]) in (tuple, list):
214
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
215
+ state.skip_next = False
216
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
217
+
218
+
219
+ def clear_history(request: gr.Request):
220
+ state = default_conversation.copy()
221
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
222
+ (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
223
+
224
+
225
+ def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
226
+ ratio_w = image_w * 1.0 / default_wh
227
+ ratio_h = image_h * 1.0 / default_wh
228
+
229
+ new_box = [int(box[0] * ratio_w), int(box[1] * ratio_h), \
230
+ int(box[2] * ratio_w), int(box[3] * ratio_h)]
231
+ return new_box
232
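+ # Example: with the default 1000x1000 vocab space, resize_bbox([100, 200, 300, 400],
+ # image_w=500, image_h=500) returns [50, 100, 150, 200] in real image pixels.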
+
233
+
234
+ def show_location(sketch_pad, chatbot):
235
+ image = sketch_pad['image']
236
+ img_w, img_h = image.size
237
+ new_bboxes = []
238
+ old_bboxes = []
239
+ # chatbot[0] is image.
240
+ text = chatbot[1:]
241
+ for round_i in text:
242
+ human_input = round_i[0]
243
+ model_output = round_i[1]
244
+ # TODO: Difference: vocab representation.
245
+ # pattern = r'\[x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?), x\d*=(\d+(?:\.\d+)?), y\d*=(\d+(?:\.\d+)?)\]'
246
+ pattern = r'\[(\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?), (\d+(?:\.\d+)?)\]'
247
+ matches = re.findall(pattern, model_output)
248
+ for match in matches:
249
+ x1, y1, x2, y2 = map(int, match)
250
+ new_box = resize_bbox([x1, y1, x2, y2], img_w, img_h)
251
+ new_bboxes.append(new_box)
252
+ old_bboxes.append([x1, y1, x2, y2])
253
+
254
+ set_old_bboxes = sorted(set(map(tuple, old_bboxes)), key=list(map(tuple, old_bboxes)).index)
255
+ list_old_bboxes = list(map(list, set_old_bboxes))
256
+
257
+ set_bboxes = sorted(set(map(tuple, new_bboxes)), key=list(map(tuple, new_bboxes)).index)
258
+ list_bboxes = list(map(list, set_bboxes))
259
+
260
+ output_image = deepcopy(image)
261
+ draw = ImageDraw.Draw(output_image)
262
+ #TODO: change from local to online path
263
+ font = ImageFont.truetype("./DejaVuSans.ttf", 28)
264
+ for i in range(len(list_bboxes)):
265
+ x1, y1, x2, y2 = list_old_bboxes[i]
266
+ x1_new, y1_new, x2_new, y2_new = list_bboxes[i]
267
+ obj_string = '[obj{}]'.format(i)
268
+ for round_i in text:
269
+ model_output = round_i[1]
270
+ model_output = model_output.replace('[{}, {}, {}, {}]'.format(x1, y1, x2, y2), obj_string)
271
+ round_i[1] = model_output
272
+ draw.rectangle([(x1_new, y1_new), (x2_new, y2_new)], outline="red", width=3)
273
+ draw.text((x1_new+2, y1_new+5), obj_string[1:-1], fill="red", font=font)
274
+
275
+ return (output_image, [chatbot[0]] + text, disable_btn)
276
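+ # show_location scans model replies for bracketed boxes such as "[120, 60, 480, 300]"
+ # (vocab-space coordinates), rescales them to the image size, draws them in red,
+ # and replaces each box in the chat text with an [objN] placeholder.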
+
277
+
278
+ def add_text(state, text, image_process_mode, original_image, sketch_pad, request: gr.Request):
279
+ print("add text called!")
280
+
281
+
282
+ image = sketch_pad['image']
283
+ print("text", text, "and : ", len(text))
284
+ print("Image path", original_image)
285
+
286
+ if len(text) <= 0 and image is None:
287
+ state.skip_next = True
288
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
289
+ if args.moderate:
290
+ flagged = violates_moderation(text)
291
+ if flagged:
292
+ state.skip_next = True
293
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
294
+ no_change_btn,) * 5
295
+
296
+ text = text[:1536] # Hard cut-off
297
+ if original_image is None:
298
+ assert image is not None
299
+ original_image = image.copy()
300
+ print('No location, copy original image in add_text')
301
+
302
+ if image is not None:
303
+ if state.first_round:
304
+ text = text[:1200] # Hard cut-off for images
305
+ if '<image>' not in text:
306
+ # text = '<Image><image></Image>' + text
307
+ text = text + '\n<image>'
308
+ text = (text, original_image, image_process_mode)
309
+ if len(state.get_images(return_pil=True)) > 0:
310
+ new_state = default_conversation.copy()
311
+ new_state.first_round = False
312
+ state=new_state
313
+ print('First round add image finished.')
314
+
315
+ state.append_message(state.roles[0], text)
316
+ state.append_message(state.roles[1], None)
317
+ state.skip_next = False
318
+ return (state, state.to_gradio_chatbot(), "", original_image) + (disable_btn,) * 5
319
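+ # add_text caps the prompt length, appends the "\n<image>" token on the first image
+ # round when missing, and stores (text, image, image_process_mode) in the conversation state.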
+
320
+
321
+ def post_process_code(code):
322
+ sep = "\n```"
323
+ if sep in code:
324
+ blocks = code.split(sep)
325
+ if len(blocks) % 2 == 1:
326
+ for i in range(1, len(blocks), 2):
327
+ blocks[i] = blocks[i].replace("\\_", "_")
328
+ code = sep.join(blocks)
329
+ return code
330
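+ # Example: post_process_code("text\n```\nmy\_var = 1\n```") un-escapes "\_" inside the
+ # complete code fence and returns "text\n```\nmy_var = 1\n```"; the odd/even block count
+ # check leaves unterminated fences untouched.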
+
331
+
332
+ def find_indices_in_order(str_list, STR):
333
+ indices = []
334
+ i = 0
335
+ while i < len(STR):
336
+ for element in str_list:
337
+ if STR[i:i+len(element)] == element:
338
+ indices.append(str_list.index(element))
339
+ i += len(element) - 1
340
+ break
341
+ i += 1
342
+ return indices
343
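+ # Example: find_indices_in_order(['[region0]', '[region1]'], 'Compare [region1] to [region0]')
+ # returns [1, 0], i.e. the placeholder indices in the order they appear in STR.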
+
344
+
345
+ def format_region_prompt(prompt, refer_input_state):
346
+ # Find regions in prompts and assign corresponding region masks
347
+ refer_input_state['region_masks_in_prompts'] = []
348
+ indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
349
+ refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
350
+
351
+ # Find regions in prompts and replace with real coordinates and region feature token.
352
+ for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
353
+ prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
354
+ return prompt
355
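+ # Illustrative example: with one stored region whose coordinates are '[100, 200, 300, 400]',
+ # the prompt "What is in [region0]?" becomes "What is in [100, 200, 300, 400] <region_fea>?",
+ # and region_masks_in_prompts is filled with the matching masks in prompt order.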
+
356
+ @spaces.GPU()
357
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_input_state, request: gr.Request):
358
+ # def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
359
+ start_tstamp = time.time()
360
+ model_name = model_selector
361
+
362
+ if state.skip_next:
363
+ # This generate call is skipped due to invalid inputs
364
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
365
+ return
366
+
367
+ print("state messages: ", state.messages)
368
+ if len(state.messages) == state.offset + 2:
369
+ # First round of conversation
370
+ # template_name = 'ferret_v1'
371
+ template_name = 'ferret_gemma_instruct'
372
+ # Below is LLaVA's original templates.
373
+ # if "llava" in model_name.lower():
374
+ # if 'llama-2' in model_name.lower():
375
+ # template_name = "llava_llama_2"
376
+ # elif "v1" in model_name.lower():
377
+ # if 'mmtag' in model_name.lower():
378
+ # template_name = "v1_mmtag"
379
+ # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
380
+ # template_name = "v1_mmtag"
381
+ # else:
382
+ # template_name = "llava_v1"
383
+ # elif "mpt" in model_name.lower():
384
+ # template_name = "mpt"
385
+ # else:
386
+ # if 'mmtag' in model_name.lower():
387
+ # template_name = "v0_mmtag"
388
+ # elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
389
+ # template_name = "v0_mmtag"
390
+ # else:
391
+ # template_name = "llava_v0"
392
+ # elif "mpt" in model_name:
393
+ # template_name = "mpt_text"
394
+ # elif "llama-2" in model_name:
395
+ # template_name = "llama_2"
396
+ # else:
397
+ # template_name = "vicuna_v1"
398
+ new_state = conv_templates[template_name].copy()
399
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
400
+ new_state.append_message(new_state.roles[1], None)
401
+ state = new_state
402
+ state.first_round = False
403
+
404
+ # # Query worker address
405
+ # controller_url = args.controller_url
406
+ # ret = requests.post(controller_url + "/get_worker_address",
407
+ # json={"model": model_name})
408
+ # worker_addr = ret.json()["address"]
409
+ # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
410
+
411
+ # No available worker
412
+ # if worker_addr == "":
413
+ # state.messages[-1][-1] = server_error_msg
414
+ # yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
415
+ # return
416
+
417
+ # Construct prompt
418
+ prompt = state.get_prompt()
419
+ if args.add_region_feature:
420
+ prompt = format_region_prompt(prompt, refer_input_state)
421
+
422
+ all_images = state.get_images(return_pil=True)
423
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
424
+ for image, hash in zip(all_images, all_image_hash):
425
+ t = datetime.datetime.now()
426
+ # Save each input image under LOGDIR/serve_images so the local inference step below can read it from disk.
427
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
428
+ if not os.path.isfile(filename):
429
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
430
+ image.save(filename)
431
+
432
+ # Make requests
433
+ pload = {
434
+ "model": model_name,
435
+ "prompt": prompt,
436
+ "temperature": float(temperature),
437
+ "top_p": float(top_p),
438
+ "max_new_tokens": min(int(max_new_tokens), 1536),
439
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
440
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
441
+ }
442
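+ # The payload mirrors the original worker_generate_stream API; since generation now
+ # runs locally via inference_and_run, it is kept mainly for logging.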
+ logger.info(f"==== request ====\n{pload}")
443
+ if args.add_region_feature:
444
+ pload['region_masks'] = refer_input_state['region_masks_in_prompts']
445
+ logger.info(f"==== add region_masks_in_prompts to request ====\n")
446
+
447
+ pload['images'] = state.get_images()
448
+ print(f'Input Prompt: {prompt}')
449
+ print("all_image_hash", all_image_hash)
450
+
451
+ state.messages[-1][-1] = "▌"
452
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
453
+
454
+ try:
455
+ # Stream output
456
+ stop = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
457
+ #TODO: define inference and run function
458
+ results, extracted_texts = inference_and_run(
459
+ image_dir=os.path.dirname(filename), image_path=os.path.basename(filename), # assumed mapping: directory saved above + file name written into eval.json (not just the hash)
460
+ prompt=prompt,
461
+ model_path=model_name,
462
+ conv_mode="ferret_gemma_instruct", # Default mode from the original function
463
+ temperature=temperature,
464
+ top_p=top_p,
465
+ max_new_tokens=max_new_tokens,
466
+ stop=stop
467
+ )
468
+
469
+ # response = requests.post(worker_addr + "/worker_generate_stream",
470
+ # headers=headers, json=pload, stream=True, timeout=10)
471
+ # extracted_texts comes from the local inference subprocess rather than a streaming
+ # HTTP response, so handle it as plain text (a minimal fix, assuming inference_and_run
+ # returns the generated text or a list of texts).
+ response = extracted_texts
+ logger.info(f"This is the response: {response}")
+
+ if response is None:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ if isinstance(response, (list, tuple)):
+ response = response[0] if len(response) > 0 else ""
+ output = post_process_code(str(response).strip())
+ state.messages[-1][-1] = output + "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ except Exception as e:
489
+ state.messages[-1][-1] = server_error_msg
490
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
491
+ return
492
+
493
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
494
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
495
+
496
+ finish_tstamp = time.time()
497
+ logger.info(f"{output}")
498
+
499
+ with open(get_conv_log_filename(), "a") as fout:
500
+ data = {
501
+ "tstamp": round(finish_tstamp, 4),
502
+ "type": "chat",
503
+ "model": model_name,
504
+ "start": round(start_tstamp, 4),
505
+ "finish": round(start_tstamp, 4),
506
+ "state": state.dict(),
507
+ "images": all_image_hash,
508
+ "ip": request.client.host,
509
+ }
510
+ fout.write(json.dumps(data) + "\n")
511
+
512
+ title_markdown = ("""
513
+ # 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
514
+ """)
515
+ # [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
516
+
517
+ tos_markdown = ("""
518
+ ### Terms of use
519
+ By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
520
+ """)
521
+
522
+
523
+ learn_more_markdown = ("""
524
+ ### License
525
+ The service is a research preview intended for non-commercial use only
526
+ """)
527
+
528
+
529
+ css = code_highlight_css + """
530
+ pre {
531
+ white-space: pre-wrap; /* Since CSS 2.1 */
532
+ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
533
+ white-space: -pre-wrap; /* Opera 4-6 */
534
+ white-space: -o-pre-wrap; /* Opera 7 */
535
+ word-wrap: break-word; /* Internet Explorer 5.5+ */
536
+ }
537
+ """
538
+
539
+ Instructions = '''
540
+ Instructions:
541
+ 1. Select a 'Referring Input Type'.
542
+ 2. Draw on the image to refer to a region/point.
543
+ 3. Copy the region id (e.g. [region0]) from the 'Referring Input Cache' into your chat message to refer to that region.
544
+ '''
545
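+ # Illustrative chat prompt once a region has been drawn:
+ # "What is the widget in [region0] used for?"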
+ from gradio.events import Dependency
546
+
547
+ class ImageMask(gr.components.Image):
548
+ """
549
+ Sets: source="canvas", tool="sketch"
550
+ """
551
+
552
+ is_template = True
553
+
554
+ def __init__(self, **kwargs):
555
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
556
+
557
+ def preprocess(self, x):
558
+ return super().preprocess(x)
559
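+ # Note: source="upload" / tool="sketch" follow the Gradio 3.x Image API; with the
+ # gradio version pinned in requirements.txt this may need gr.ImageEditor instead.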
+ from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
560
+ from gradio.blocks import Block
561
+ if TYPE_CHECKING:
562
+ from gradio.components import Timer
563
+
564
+
565
+ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
566
+ if type(input) == dict:
567
+ image = deepcopy(input['image'])
568
+ mask = deepcopy(input['mask'])
569
+ else:
570
+ mask = deepcopy(input)
571
+
572
+ # W, H -> H, W, 3
573
+ image_new = np.asarray(image)
574
+ img_height = image_new.shape[0]
575
+ img_width = image_new.shape[1]
576
+
577
+ # W, H, 4 -> H, W
578
+ mask_new = np.asarray(mask)[:,:,0].copy()
579
+ mask_new = torch.from_numpy(mask_new)
580
+ mask_new = (F.interpolate(mask_new.unsqueeze(0).unsqueeze(0), (img_height, img_width), mode='bilinear') > 0)
581
+ mask_new = mask_new[0, 0].transpose(1, 0).long()
582
+
583
+ if len(refer_input_state['masks']) == 0:
584
+ last_mask = torch.zeros_like(mask_new)
585
+ else:
586
+ last_mask = refer_input_state['masks'][-1]
587
+
588
+ diff_mask = mask_new - last_mask
589
+ if torch.all(diff_mask == 0):
590
+ print('Init Uploading Images.')
591
+ return (refer_input_state, refer_text_show, image)
592
+ else:
593
+ refer_input_state['masks'].append(mask_new)
594
+
595
+ if input_mode == 'Point':
596
+ nonzero_points = diff_mask.nonzero()
597
+ nonzero_points_avg_x = torch.median(nonzero_points[:, 0])
598
+ nonzero_points_avg_y = torch.median(nonzero_points[:, 1])
599
+ sampled_coor = [nonzero_points_avg_x, nonzero_points_avg_y]
600
+ # pdb.set_trace()
601
+ cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
602
+ elif input_mode == 'Box' or input_mode == 'Sketch':
603
+ # pdb.set_trace()
604
+ x1x2 = diff_mask.max(1)[0].nonzero()[:, 0]
605
+ y1y2 = diff_mask.max(0)[0].nonzero()[:, 0]
606
+ y1, y2 = y1y2.min(), y1y2.max()
607
+ x1, x2 = x1x2.min(), x1x2.max()
608
+ # pdb.set_trace()
609
+ sampled_coor = [x1, y1, x2, y2]
610
+ if input_mode == 'Box':
611
+ cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height)
612
+ else:
613
+ cur_region_masks = generate_mask_for_feature(sampled_coor, raw_w=img_width, raw_h=img_height, mask=diff_mask)
614
+ else:
615
+ raise NotImplementedError(f'Input mode of {input_mode} is not Implemented.')
616
+
617
+ # TODO(haoxuan): Hack img_size to be 224 here, need to make it an argument.
618
+ if len(sampled_coor) == 2:
619
+ point_x = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
620
+ point_y = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
621
+ cur_region_coordinates = f'[{int(point_x)}, {int(point_y)}]'
622
+ elif len(sampled_coor) == 4:
623
+ point_x1 = int(VOCAB_IMAGE_W * sampled_coor[0] / img_width)
624
+ point_y1 = int(VOCAB_IMAGE_H * sampled_coor[1] / img_height)
625
+ point_x2 = int(VOCAB_IMAGE_W * sampled_coor[2] / img_width)
626
+ point_y2 = int(VOCAB_IMAGE_H * sampled_coor[3] / img_height)
627
+ cur_region_coordinates = f'[{int(point_x1)}, {int(point_y1)}, {int(point_x2)}, {int(point_y2)}]'
628
+
629
+ cur_region_id = len(refer_input_state['region_placeholder_tokens'])
630
+ cur_region_token = DEFAULT_REGION_REFER_TOKEN.split(']')[0] + str(cur_region_id) + ']'
631
+ refer_input_state['region_placeholder_tokens'].append(cur_region_token)
632
+ refer_input_state['region_coordinates'].append(cur_region_coordinates)
633
+ refer_input_state['region_masks'].append(cur_region_masks)
634
+ assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
635
+ refer_text_show.append((cur_region_token, ''))
636
+
637
+ # Show Parsed Referring.
638
+ imagebox_refer = draw_box(sampled_coor, cur_region_masks, \
639
+ cur_region_token, imagebox_refer, input_mode)
640
+
641
+ return (refer_input_state, refer_text_show, imagebox_refer)
642
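+ # draw() diffs the new sketch mask against the previous one, derives a point or box
+ # from the changed pixels, scales it into the 1000x1000 vocab space, registers a new
+ # [regionN] token with its mask, and overlays the parsed region on the image.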
+
643
+ def build_demo(embed_mode):
644
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False)
645
+ with gr.Blocks(title="FERRET", theme=gr.themes.Base(), css=css) as demo:
646
+ state = gr.State()
647
+
648
+ if not embed_mode:
649
+ gr.Markdown(title_markdown)
650
+ gr.Markdown(Instructions)
651
+
652
+ with gr.Row():
653
+ with gr.Column(scale=4):
654
+ with gr.Row(elem_id="model_selector_row"):
655
+ model_selector = gr.Dropdown(
656
+ choices=models,
657
+ value=models[0] if len(models) > 0 else "",
658
+ interactive=True,
659
+ show_label=False,
660
+ container=False)
661
+
662
+ original_image = gr.Image(type="pil", visible=False)
663
+ image_process_mode = gr.Radio(
664
+ ["Raw+Processor", "Crop", "Resize", "Pad"],
665
+ value="Raw+Processor",
666
+ label="Preprocess for non-square image",
667
+ visible=False)
668
+
669
+ # Added for any-format input.
670
+ sketch_pad = ImageMask(label="Image & Sketch", type="pil", elem_id="img2text")
671
+ refer_input_mode = gr.Radio(
672
+ ["Point", "Box", "Sketch"],
673
+ value="Point",
674
+ label="Referring Input Type")
675
+ refer_input_state = gr.State({'region_placeholder_tokens':[],
676
+ 'region_coordinates':[],
677
+ 'region_masks':[],
678
+ 'region_masks_in_prompts':[],
679
+ 'masks':[],
680
+ })
681
+ refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
682
+
683
+ imagebox_refer = gr.Image(type="pil", label="Parsed Referring Input")
684
+ imagebox_output = gr.Image(type="pil", label='Output Vis')
685
+
686
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
687
+ # gr.Examples(examples=[
688
+ # # [f"{cur_dir}/examples/harry-potter-hogwarts.jpg", "What is in [region0]? And what do people use it for?"],
689
+ # # [f"{cur_dir}/examples/ingredients.jpg", "What objects are in [region0] and [region1]?"],
690
+ # # [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image? And tell me the coordinates of mentioned objects."],
691
+ # [f"{cur_dir}/examples/ferret.jpg", "What's the relationship between object [region0] and object [region1]?"],
692
+ # [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here? Tell me the coordinates in response."],
693
+ # [f"{cur_dir}/examples/flickr_9472793441.jpg", "Describe the image in details."],
694
+ # # [f"{cur_dir}/examples/coco_000000281759.jpg", "What are the locations of the woman wearing a blue dress, the woman in flowery top, the girl in purple dress, the girl wearing green shirt?"],
695
+ # [f"{cur_dir}/examples/room_planning.jpg", "How to improve the design of the given room?"],
696
+ # [f"{cur_dir}/examples/make_sandwitch.jpg", "How can I make a sandwich with available ingredients?"],
697
+ # [f"{cur_dir}/examples/bathroom.jpg", "What is unusual about this image?"],
698
+ # [f"{cur_dir}/examples/kitchen.png", "Is the object a man or a chicken? Explain the reason."],
699
+ # ], inputs=[sketch_pad, textbox])
700
+
701
+ with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
702
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
703
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
704
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
705
+
706
+ with gr.Column(scale=5):
707
+ chatbot = gr.Chatbot(elem_id="chatbot", label="FERRET", visible=False).style(height=750)
708
+ with gr.Row():
709
+ with gr.Column(scale=8):
710
+ textbox.render()
711
+ with gr.Column(scale=1, min_width=60):
712
+ submit_btn = gr.Button(value="Submit", visible=False)
713
+ with gr.Row(visible=False) as button_row:
714
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
715
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
716
+ # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
717
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
718
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
719
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
720
+ location_btn = gr.Button(value="🪄 Show location", interactive=False)
721
+
722
+ if not embed_mode:
723
+ gr.Markdown(tos_markdown)
724
+ gr.Markdown(learn_more_markdown)
725
+ url_params = gr.JSON(visible=False)
726
+
727
+ # Register listeners
728
+ btn_list = [upvote_btn, downvote_btn, location_btn, regenerate_btn, clear_btn]
729
+ upvote_btn.click(upvote_last_response,
730
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
731
+ downvote_btn.click(downvote_last_response,
732
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, location_btn])
733
+ # flag_btn.click(flag_last_response,
734
+ # [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
735
+ regenerate_btn.click(regenerate, [state, image_process_mode],
736
+ [state, chatbot, textbox] + btn_list).then(
737
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
738
+ [state, chatbot] + btn_list)
739
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox_output, original_image] + btn_list + \
740
+ [sketch_pad, refer_input_state, refer_text_show, imagebox_refer])
741
+ location_btn.click(show_location,
742
+ [sketch_pad, chatbot], [imagebox_output, chatbot, location_btn])
743
+
744
+
745
+ #TODO: fix bug text and image not adding when clicking submit
746
+ textbox.submit(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
747
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
748
+ [state, chatbot] + btn_list)
749
+
750
+ submit_btn.click(add_text, [state, textbox, image_process_mode, original_image, sketch_pad], [state, chatbot, textbox, original_image] + btn_list
751
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens, refer_input_state],
752
+ [state, chatbot] + btn_list)
753
+
754
+
755
+
756
+ sketch_pad.edit(
757
+ draw,
758
+ inputs=[refer_input_mode, sketch_pad, refer_input_state, refer_text_show, imagebox_refer],
759
+ outputs=[refer_input_state, refer_text_show, imagebox_refer],
760
+ queue=True,
761
+ )
762
+
763
+ if args.model_list_mode == "once":
764
+ demo.load(load_demo, [url_params], [state, model_selector,
765
+ chatbot, textbox, submit_btn, button_row, parameter_row],
766
+ _js=get_window_url_params)
767
+ elif args.model_list_mode == "reload":
768
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector,
769
+ chatbot, textbox, submit_btn, button_row, parameter_row])
770
+ else:
771
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
772
+
773
+ return demo
774
+
775
+
776
+ if __name__ == "__main__":
777
+ parser = argparse.ArgumentParser()
778
+ parser.add_argument("--host", type=str, default="0.0.0.0")
779
+ parser.add_argument("--port", type=int)
780
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
781
+ parser.add_argument("--concurrency-count", type=int, default=8)
782
+ parser.add_argument("--model-list-mode", type=str, default="once",
783
+ choices=["once", "reload"])
784
+ parser.add_argument("--share", action="store_true")
785
+ parser.add_argument("--moderate", action="store_true")
786
+ parser.add_argument("--embed", action="store_true")
787
+ parser.add_argument("--add_region_feature", action="store_true")
788
+ args = parser.parse_args()
789
+ logger.info(f"args: {args}")
790
+
791
+ models = get_model_list()
792
+
793
+ logger.info(args)
794
+ demo = build_demo(args.embed)
795
+ demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
796
+ api_open=False).launch(
797
+ server_name=args.host, server_port=args.port, share=True)
eval.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"id": 0, "image": "8b23f327b90b6211049acd36e3f99975.jpg", "image_h": 433, "image_w": 400, "conversations": [{"from": "human", "value": "<image>\nA chat between a human and an AI that understands visuals. In images, [x, y] denotes points: top-left [0, 0], bottom-right [width-1, height-1]. Increasing x moves right; y moves down. Bounding box: [x1, y1, x2, y2]. Image size: 1000x1000. Follow instructions.<start_of_turn>user\n<image>\ndescribe the image in details<end_of_turn>\n<start_of_turn>model\n"}]}]
gradio_web_server.log ADDED
The diff for this file is too large to render. See raw diff
 
inference.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image, ImageDraw
5
  import re
6
  import json
7
  import subprocess
 
8
 
9
  def process_inference_results(results, process_image=False):
10
  """
@@ -38,8 +39,9 @@ def process_inference_results(results, process_image=False):
38
  return processed_images, extracted_texts
39
 
40
  return extracted_texts
41
-
42
- def inference_and_run(image_path, prompt, conv_mode="ferret_gemma_instruct", model_path="jadechoghari/Ferret-UI-Gemma2b", box=None, process_image=False):
 
43
  """
44
  Run the inference and capture the errors for debugging.
45
  """
@@ -63,10 +65,12 @@ def inference_and_run(image_path, prompt, conv_mode="ferret_gemma_instruct", mod
63
  "python", "-m", "model_UI",
64
  "--model_path", model_path,
65
  "--data_path", "eval.json",
66
- "--image_path", ".",
67
  "--answers_file", "eval_output.jsonl",
68
  "--num_beam", "1",
69
- "--max_new_tokens", "32",
 
 
70
  "--conv_mode", conv_mode
71
  ]
72
 
@@ -98,4 +102,4 @@ def inference_and_run(image_path, prompt, conv_mode="ferret_gemma_instruct", mod
98
  except subprocess.CalledProcessError as e:
99
  print(f"Error occurred during inference:\n{e}")
100
  print(f"Subprocess output:\n{e.output}")
101
- return None, None
 
5
  import re
6
  import json
7
  import subprocess
8
+ import spaces
9
 
10
  def process_inference_results(results, process_image=False):
11
  """
 
39
  return processed_images, extracted_texts
40
 
41
  return extracted_texts
42
+
43
+ @spaces.GPU()
44
+ def inference_and_run(image_dir, image_path, prompt, conv_mode="ferret_gemma_instruct", model_path="jadechoghari/Ferret-UI-Gemma2b", box=None, process_image=False, temperature=0.2, top_p=0.7, max_new_tokens=512, stop='<eos>'):
45
  """
46
  Run the inference and capture the errors for debugging.
47
  """
 
65
  "python", "-m", "model_UI",
66
  "--model_path", model_path,
67
  "--data_path", "eval.json",
68
+ "--image_path", image_dir,
69
  "--answers_file", "eval_output.jsonl",
70
  "--num_beam", "1",
71
+ "--temperature", str(temperature),
72
+ "--top_p", str(top_p),
73
+ "--max_new_tokens", str(max_new_tokens),
74
  "--conv_mode", conv_mode
75
  ]
76
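+ # Illustrative resulting command (with the defaults above):
+ # python -m model_UI --model_path jadechoghari/Ferret-UI-Gemma2b --data_path eval.json
+ # --image_path <image_dir> --answers_file eval_output.jsonl --num_beam 1
+ # --temperature 0.2 --top_p 0.7 --max_new_tokens 512 --conv_mode ferret_gemma_instruct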
 
 
102
  except subprocess.CalledProcessError as e:
103
  print(f"Error occurred during inference:\n{e}")
104
  print(f"Subprocess output:\n{e.output}")
105
+ return None, None
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ einops
2
+ fastapi
3
+ gradio==4.31.3
4
+ markdown2[all]
5
+ numpy
6
+ requests
7
+ sentencepiece==0.1.99
8
+ tokenizers>=0.12.1
9
+ torch
10
+ torchvision
11
+ uvicorn
12
+ wandb
13
+ shortuuid
14
+ httpx==0.24.0
15
+ deepspeed==0.9.5
16
+ peft==0.4.0
17
+ transformers @ git+https://github.com/huggingface/transformers.git@cae78c46
18
+ accelerate==0.21.0
19
+ bitsandbytes==0.41.0
20
+ scikit-learn==1.5.0
21
+ einops==0.6.1
22
+ einops-exts==0.0.4
23
+ timm==0.6.13
24
+ openai
25
+ gradio_client==0.1.2
serve_images/2024-10-19/8b23f327b90b6211049acd36e3f99975.jpg ADDED