Ivy1997 committed on
Commit
dd3b350
1 Parent(s): 4aa36bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -10,19 +10,18 @@ import warnings
10
 
11
  warnings.filterwarnings("ignore")
12
 
13
- pretrained = "AI-Safeguard/Ivy-VL-llava"
14
  model_name = "llava_qwen"
15
- device = "cpu"
16
  device_map = "auto"
17
 
18
  # Load model, tokenizer, and image processor
19
  tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
20
  model.eval()
21
 
22
- def respond(image_path, question, temperature, max_tokens):
23
  try:
24
  # Load and process the image
25
- image = Image.open(image_path)
26
  image_tensor = process_images([image], image_processor, model.config)
27
  image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
28
 
@@ -57,12 +56,12 @@ def respond(image_path, question, temperature, max_tokens):
57
  def chat_interface(image, question, temperature, max_tokens):
58
  if not image or not question:
59
  return "Please provide both an image and a question."
60
- return respond(image.name, question, temperature, max_tokens)
61
 
62
  demo = gr.Interface(
63
  fn=chat_interface,
64
  inputs=[
65
- gr.Image(type="file", label="Input Image"),
66
  gr.Textbox(label="Question"),
67
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
68
  gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max Tokens"),
 
10
 
11
  warnings.filterwarnings("ignore")
12
 
13
+ pretrained = "/tmp/pre-trained/AI-Safeguard/Ivy-VL-llava"
14
  model_name = "llava_qwen"
15
+ device = "cuda"
16
  device_map = "auto"
17
 
18
  # Load model, tokenizer, and image processor
19
  tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
20
  model.eval()
21
 
22
+ def respond(image, question, temperature, max_tokens):
23
  try:
24
  # Load and process the image
 
25
  image_tensor = process_images([image], image_processor, model.config)
26
  image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
27
 
 
56
  def chat_interface(image, question, temperature, max_tokens):
57
  if not image or not question:
58
  return "Please provide both an image and a question."
59
+ return respond(image, question, temperature, max_tokens)
60
 
61
  demo = gr.Interface(
62
  fn=chat_interface,
63
  inputs=[
64
+ gr.Image(type="pil", label="Input Image"),
65
  gr.Textbox(label="Question"),
66
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
67
  gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max Tokens"),