Abhaykoul committed
Commit fcae467
1 Parent(s): 49a2a93

Update app.py

Files changed (1)
  1. app.py +41 -25
app.py CHANGED
@@ -1,39 +1,54 @@
 from __future__ import annotations

-import os
+import spaces
+
+import gradio as gr
+from threading import Thread
+from transformers import TextIteratorStreamer
 import hashlib
+import os
+
+from transformers import AutoModel, AutoProcessor
 import torch
-from threading import Thread
-from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
-import gradio as gr

-# Initialize the model and processor
-def initialize_model_and_processor():
-    model = AutoModel.from_pretrained("OEvortex/HelpingAI-Vision", torch_dtype=torch.float16, trust_remote_code=True).to("cuda" if torch.cuda.is_available() else "cpu")
-    processor = AutoProcessor.from_pretrained("OEvortex/HelpingAI-Vision", trust_remote_code=True)
-    return model, processor
+model = AutoModel.from_pretrained("OEvortex/HelpingAI-Vision", torch_dtype=torch.float16, trust_remote_code=True).to("cuda")
+
+processor = AutoProcessor.from_pretrained("OEvortex/HelpingAI-Vision", trust_remote_code=True)
+
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+    DTYPE = torch.float16
+else:
+    DEVICE = "cpu"
+    DTYPE = torch.float32

-# Function to process images and cache results
 def cached_vision_process(image, max_crops, num_tokens):
     image_hash = hashlib.sha256(image.tobytes()).hexdigest()
     cache_path = f"visual_cache/{image_hash}-{max_crops}-{num_tokens}.pt"
     if os.path.exists(cache_path):
-        return torch.load(cache_path).to(model.device, dtype=model.dtype)
+        return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
     else:
         processor_outputs = processor.image_processor([image], max_crops)
-        pixel_values = [value.to(model.device, model.dtype) for value in processor_outputs["pixel_values"]]
-        coords = [value.to(model.device, model.dtype) for value in processor_outputs["coords"]]
+        pixel_values = processor_outputs["pixel_values"]
+        pixel_values = [
+            value.to(model.device).to(model.dtype) for value in pixel_values
+        ]
+        coords = processor_outputs["coords"]
+        coords = [value.to(model.device).to(model.dtype) for value in coords]
         image_outputs = model.vision_model(pixel_values, coords, num_tokens)
         image_features = model.multi_modal_projector(image_outputs)
         os.makedirs("visual_cache", exist_ok=True)
         torch.save(image_features, cache_path)
-        return image_features.to(model.device, model.dtype)
+        return image_features.to(DEVICE, dtype=DTYPE)

-# Function to answer questions about images
+@spaces.GPU(duration=20)
 def answer_question(image, question, max_crops, num_tokens, sample, temperature, top_k):
-    if not question.strip() or not image:
-        return "Please provide both an image and a question."
-
+    if question is None or question.strip() == "":
+        yield "Please ask me anything"
+        return
+    if image is None:
+        yield "Please upload a picture"
+        return
     prompt = f"""user
 <image>
 {question}
@@ -42,7 +57,6 @@ assistant
     streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
     with torch.inference_mode():
         inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)
-
         generation_kwargs = {
             "input_ids": inputs["input_ids"],
             "attention_mask": inputs["attention_mask"],
@@ -71,15 +85,17 @@ assistant
         yield buffer
     return buffer

-# Initialize the model and processor
-model, processor = initialize_model_and_processor()

-# Gradio interface setup
 with gr.Blocks() as demo:
     with gr.Group():
         with gr.Row():
-            prompt = gr.Textbox(label="Question", placeholder="e.g. Describe this?", scale=4)
-            submit = gr.Button("Send", scale=1)
+            prompt = gr.Textbox(
+                label="Question", placeholder="e.g. Discribe this?", scale=4
+            )
+            submit = gr.Button(
+                "Send",
+                scale=1,
+            )
         with gr.Row():
             max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
             num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
@@ -94,4 +110,4 @@ with gr.Blocks() as demo:
     submit.click(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)
     prompt.submit(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)

-demo.queue().launch(debug=True)
+demo.queue().launch(debug=True)
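
The reworked answer_question pairs Hugging Face's ZeroGPU spaces.GPU decorator with TextIteratorStreamer driven from a background thread. A minimal sketch of that pattern follows; stream_reply, inputs, model, and tokenizer are illustrative placeholders, not the Space's actual objects.

# Sketch of the ZeroGPU + streaming pattern used in the new app.py.
# spaces.GPU requests a GPU slice for the decorated call on a ZeroGPU Space;
# TextIteratorStreamer lets generate() run in a worker thread while this
# generator yields partial text that Gradio renders incrementally.
from threading import Thread

import spaces
from transformers import TextIteratorStreamer


@spaces.GPU(duration=20)  # same duration as in the commit; releases the GPU after ~20 s
def stream_reply(inputs, model, tokenizer):
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=200)
    # generate() blocks, so it runs in a thread; the streamer is consumed here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # each partial string replaces the previous one in the UI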