prithivMLmods committed
Commit ea33f68 (verified)
Parent(s): 358adb9

Update app.py

Files changed (1)
app.py +59 -97
app.py CHANGED
@@ -6,55 +6,33 @@ import time
 import torch
 import spaces
 
-DESCRIPTION = """
-# Qwen2.5-VL-3B/7B-Instruct
-"""
-
-css = '''
-h1 {
-  text-align: center;
-  display: block;
-}
-#duplicate-button {
-  margin: auto;
-  color: #fff;
-  background: #1565c0;
-  border-radius: 100vh;
-}
-'''
-
-# Define an animated progress bar HTML snippet
+# -----------------------
+# Progress Bar Helper
+# -----------------------
 def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for a thin progress bar with a label.
+    The progress bar is styled as a dark purple animated bar.
+    """
     return f'''
-<div style="display: flex; align-items: center;">
-    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-    <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
-        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-    </div>
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
 </div>
-<style>
-@keyframes loading {{
-    0% {{ transform: translateX(-100%); }}
-    100% {{ transform: translateX(100%); }}
-}}
-</style>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
 '''
 
-# Model IDs for 3B and 7B variants
-MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
-MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
-
-# Load the processor and models for both versions
-processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
-model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_3B,
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16
-).to("cuda").eval()
-
-processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
-model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_7B,
+MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"  # else: MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16
 ).to("cuda").eval()
@@ -64,89 +42,73 @@ def model_inference(input_dict, history):
     text = input_dict["text"]
     files = input_dict["files"]
 
-    # Determine which model to use based on the prefix tag
-    if text.lower().startswith("@3b"):
-        yield progress_bar_html("processing with Qwen2.5-VL-3B-Instruct")
-        selected_model = model_3b
-        selected_processor = processor_3b
-        text = text[len("@3b"):].strip()
-    elif text.lower().startswith("@7b"):
-        yield progress_bar_html("processing with Qwen2.5-VL-7B-Instruct")
-        selected_model = model_7b
-        selected_processor = processor_7b
-        text = text[len("@7b"):].strip()
-    else:
-        yield "Error: Please prefix your query with @3b or @7b to select the model."
-        return
-
     # Load images if provided
-    if files:
-        if isinstance(files, list):
-            if len(files) > 1:
-                images = [load_image(image) for image in files]
-            elif len(files) == 1:
-                images = [load_image(files[0])]
-            else:
-                images = []
-        else:
-            images = [load_image(files)]
+    if len(files) > 1:
+        images = [load_image(image) for image in files]
+    elif len(files) == 1:
+        images = [load_image(files[0])]
     else:
         images = []
 
-    # Validate input: text query is required
-    if text == "":
-        yield "Error: Please input a text query along with the image(s) if any."
+    # Validate input
+    if text == "" and not images:
+        gr.Error("Please input a query and optionally image(s).")
+        return
+    if text == "" and images:
+        gr.Error("Please input a text query along with the image(s).")
         return
 
     # Prepare messages for the model
-    messages = [{
-        "role": "user",
-        "content": [
-            *[{"type": "image", "image": image} for image in images],
-            {"type": "text", "text": text},
-        ]
-    }]
-
-    # Apply the chat template and process the inputs
-    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = selected_processor(
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ],
+        }
+    ]
+
+    # Apply chat template and process inputs
+    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
         text=[prompt],
         images=images if images else None,
         return_tensors="pt",
         padding=True,
     ).to("cuda")
 
-    # Set up a streamer for real-time text generation
-    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
+    # Set up streamer for real-time output
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
     # Start generation in a separate thread
-    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    # Yield an animated progress message
-    yield progress_bar_html("Almost there, hold tight!")
-
+    # Stream the output
     buffer = ""
+    yield "Thinking..."
    for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer
 
-# Example inputs with model prefixes
+
+# Example inputs
 examples = [
-    [{"text": "@3b Describe the document?", "files": ["example_images/document.jpg"]}],
-    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
-    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
-    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+    [{"text": "Describe the document?", "files": ["example_images/document.jpg"]}],
+    [{"text": "What does this say?", "files": ["example_images/math.jpg"]}],
+    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
+    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description=DESCRIPTION,
-    css=css,
+    description="# **Qwen2.5-VL-7B-Instruct**",
     examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="Use Tags @3b / @7b to trigger the models"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
     stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
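Note: both hunks start at line 6 of app.py, so the file's first five lines (its import header) are unchanged and therefore not shown; the hunk header only reveals that `import time` precedes them. Judging from the names used in the diff, the header presumably looks like the sketch below; every line of it is inferred from usage, not copied from the commit.

    # Presumed import header (lines 1-5 sit outside the diff; all inferred from usage)
    from threading import Thread        # runs model.generate in the background
    import gradio as gr                 # gr.ChatInterface, gr.MultimodalTextbox, gr.Error
    from transformers import (
        AutoProcessor,
        Qwen2_5_VLForConditionalGeneration,
        TextIteratorStreamer,
    )
    from transformers.image_utils import load_image  # one plausible source of load_image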
 
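The generation path this commit keeps is the standard transformers streaming pattern: generate() blocks until it finishes, so it runs on a worker thread while the caller drains the TextIteratorStreamer and yields partial text to Gradio. Below is a minimal, self-contained sketch of that pattern with a small text-only model; the model ID and prompt are illustrative assumptions, not taken from the Space.

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Illustrative model; the pattern is identical for Qwen2.5-VL, minus the image inputs.
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("The quick brown fox", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done, so run it on a worker thread...
    Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32)).start()

    # ...while this thread consumes decoded text incrementally, as model_inference does.
    for piece in streamer:
        print(piece, end="", flush=True)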
 
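One caveat on the new validation branches: in Gradio, gr.Error is shown to the user only when it is raised; calling gr.Error(...) and then returning, as the added lines do, ends the generator without surfacing any message. A sketch of the raising form, using a hypothetical chat_fn rather than the Space's own code:

    import gradio as gr

    def chat_fn(message, history):
        # gr.Error must be *raised* to appear as an error toast in the UI;
        # gr.Error("...") on its own constructs the exception and discards it.
        if not message["text"].strip():
            raise gr.Error("Please input a text query along with the image(s).")
        yield f"You said: {message['text']}"

    demo = gr.ChatInterface(fn=chat_fn, multimodal=True)

    if __name__ == "__main__":
        demo.launch()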