paralym commited on
Commit
3eda1dd
·
verified ·
1 Parent(s): 5cddf68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -9
app.py CHANGED
@@ -6,7 +6,7 @@ from threading import Thread
6
  # import time
7
  import cv2
8
 
9
-
10
  # import copy
11
  import torch
12
 
@@ -34,8 +34,6 @@ from llava.mm_utils import (
34
 
35
  from serve_constants import html_header
36
 
37
- from PIL import Image
38
-
39
  import requests
40
  from PIL import Image
41
  from io import BytesIO
@@ -46,6 +44,9 @@ import gradio_client
46
  import subprocess
47
  import sys
48
 
 
 
 
49
  def install_gradio_4_35_0():
50
  current_version = gr.__version__
51
  if current_version != "4.35.0":
@@ -64,6 +65,11 @@ import gradio_client
64
  print(f"Gradio version: {gr.__version__}")
65
  print(f"Gradio-client version: {gradio_client.__version__}")
66
 
 
 
 
 
 
67
  class InferenceDemo(object):
68
  def __init__(
69
  self, args, model_path, tokenizer, model, image_processor, context_len
@@ -113,6 +119,16 @@ def is_valid_video_filename(name):
113
  else:
114
  return False
115
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  def sample_frames(video_file, num_frames):
118
  video = cv2.VideoCapture(video_file)
@@ -193,9 +209,14 @@ def bot(history):
193
  if type(message[0]) is tuple:
194
  images_this_term.append(message[0][0])
195
  if is_valid_video_filename(message[0][0]):
 
 
196
  num_new_images += our_chatbot.num_frames
197
- else:
 
198
  num_new_images += 1
 
 
199
  else:
200
  num_new_images = 0
201
 
@@ -209,8 +230,11 @@ def bot(history):
209
  for f in images_this_term:
210
  if is_valid_video_filename(f):
211
  image_list += sample_frames(f, our_chatbot.num_frames)
212
- else:
213
  image_list.append(load_image(f))
 
 
 
214
  image_tensor = [
215
  our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
216
  0
@@ -219,6 +243,24 @@ def bot(history):
219
  .to(our_chatbot.model.device)
220
  for f in image_list
221
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  image_tensor = torch.stack(image_tensor)
224
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
@@ -280,7 +322,19 @@ def bot(history):
280
  our_chatbot.conversation.messages[-1][-1] = outputs
281
 
282
  history[-1] = [text, outputs]
283
-
 
 
 
 
 
 
 
 
 
 
 
 
284
  return history
285
  # generate_kwargs = dict(
286
  # inputs=input_ids,
@@ -345,7 +399,7 @@ with gr.Blocks(
345
 
346
  with gr.Column():
347
  with gr.Row():
348
- chatbot = gr.Chatbot([], elem_id="chatbot", bubble_full_width=False)
349
 
350
  with gr.Row():
351
  upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
@@ -560,8 +614,8 @@ if __name__ == "__main__":
560
  argparser.add_argument("--model-base", type=str, default=None)
561
  argparser.add_argument("--num-gpus", type=int, default=1)
562
  argparser.add_argument("--conv-mode", type=str, default=None)
563
- argparser.add_argument("--temperature", type=float, default=0.2)
564
- argparser.add_argument("--max-new-tokens", type=int, default=512)
565
  argparser.add_argument("--num_frames", type=int, default=16)
566
  argparser.add_argument("--load-8bit", action="store_true")
567
  argparser.add_argument("--load-4bit", action="store_true")
 
6
  # import time
7
  import cv2
8
 
9
+ import datetime
10
  # import copy
11
  import torch
12
 
 
34
 
35
  from serve_constants import html_header
36
 
 
 
37
  import requests
38
  from PIL import Image
39
  from io import BytesIO
 
44
  import subprocess
45
  import sys
46
 
47
+ external_log_dir = "./logs"
48
+ LOGDIR = external_log_dir
49
+
50
  def install_gradio_4_35_0():
51
  current_version = gr.__version__
52
  if current_version != "4.35.0":
 
65
  print(f"Gradio version: {gr.__version__}")
66
  print(f"Gradio-client version: {gradio_client.__version__}")
67
 
68
+ def get_conv_log_filename():
69
+ t = datetime.datetime.now()
70
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
71
+ return name
72
+
73
  class InferenceDemo(object):
74
  def __init__(
75
  self, args, model_path, tokenizer, model, image_processor, context_len
 
119
  else:
120
  return False
121
 
122
+ def is_valid_image_filename(name):
123
+ image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
124
+
125
+ ext = name.split(".")[-1].lower()
126
+
127
+ if ext in image_extensions:
128
+ return True
129
+ else:
130
+ return False
131
+
132
 
133
  def sample_frames(video_file, num_frames):
134
  video = cv2.VideoCapture(video_file)
 
209
  if type(message[0]) is tuple:
210
  images_this_term.append(message[0][0])
211
  if is_valid_video_filename(message[0][0]):
212
+ # 不接受视频
213
+ raise ValueError("Video is not supported")
214
  num_new_images += our_chatbot.num_frames
215
+ elif is_valid_image_filename(message[0][0]):
216
+ print("#### Load image from local file",message[0][0])
217
  num_new_images += 1
218
+ else:
219
+ raise ValueError("Invalid image file")
220
  else:
221
  num_new_images = 0
222
 
 
230
  for f in images_this_term:
231
  if is_valid_video_filename(f):
232
  image_list += sample_frames(f, our_chatbot.num_frames)
233
+ elif is_valid_image_filename(f):
234
  image_list.append(load_image(f))
235
+ else:
236
+ raise ValueError("Invalid image file")
237
+
238
  image_tensor = [
239
  our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
240
  0
 
243
  .to(our_chatbot.model.device)
244
  for f in image_list
245
  ]
246
+ all_image_hash = []
247
+ for image_path in image_list:
248
+ with open(image_path, "rb") as image_file:
249
+ image_data = image_file.read()
250
+ image_hash = hashlib.md5(image_data).hexdigest()
251
+ all_image_hash.append(image_hash)
252
+ image = PIL.Image.open(image_path).convert("RGB")
253
+ all_images.append(image)
254
+ t = datetime.datetime.now()
255
+ filename = os.path.join(
256
+ LOGDIR,
257
+ "serve_images",
258
+ f"{t.year}-{t.month:02d}-{t.day:02d}",
259
+ f"{image_hash}.jpg",
260
+ )
261
+ if not os.path.isfile(filename):
262
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
263
+ image.save(filename)
264
 
265
  image_tensor = torch.stack(image_tensor)
266
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
 
322
  our_chatbot.conversation.messages[-1][-1] = outputs
323
 
324
  history[-1] = [text, outputs]
325
+ print("#### history",history)
326
+
327
+ with open(get_conv_log_filename(), "a") as fout:
328
+ data = {
329
+ "tstamp": round(finish_tstamp, 4),
330
+ "type": "chat",
331
+ "model": "Pangea-7b",
332
+ "start": round(start_tstamp, 4),
333
+ "finish": round(start_tstamp, 4),
334
+ "state": history,
335
+ "images": all_image_hash,
336
+ }
337
+ fout.write(json.dumps(data) + "\n")
338
  return history
339
  # generate_kwargs = dict(
340
  # inputs=input_ids,
 
399
 
400
  with gr.Column():
401
  with gr.Row():
402
+ chatbot = gr.Chatbot([], elem_id="Pangea", bubble_full_width=False, height=750)
403
 
404
  with gr.Row():
405
  upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
 
614
  argparser.add_argument("--model-base", type=str, default=None)
615
  argparser.add_argument("--num-gpus", type=int, default=1)
616
  argparser.add_argument("--conv-mode", type=str, default=None)
617
+ argparser.add_argument("--temperature", type=float, default=0.7)
618
+ argparser.add_argument("--max-new-tokens", type=int, default=4096)
619
  argparser.add_argument("--num_frames", type=int, default=16)
620
  argparser.add_argument("--load-8bit", action="store_true")
621
  argparser.add_argument("--load-4bit", action="store_true")