prithivMLmods committed
Commit 6c89696 · verified · 1 Parent(s): de8d16b

Update app.py

Files changed (1)
  1. app.py +138 -119
app.py CHANGED
@@ -1,125 +1,144 @@
- import os
- from collections.abc import Iterator
- from threading import Thread
  import gradio as gr
  import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

- DESCRIPTION = """
- # GWQ PREV
- """

- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- model_id = "prithivMLmods/GWQ2b"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.config.sliding_window = 4096
- model.eval()
-
-
- @spaces.GPU(duration=120)
- def generate(
-     message: str,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-     conversation = chat_history.copy()
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
-         ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
-         ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
-         ["What happens when the sun goes down?"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description=DESCRIPTION,
-     css_paths="style.css",
-     fill_height=True,
- )
-
-
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()

  import gradio as gr
  import spaces
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+ from qwen_vl_utils import process_vision_info
  import torch
+ from PIL import Image
+ import subprocess
+ import numpy as np
+ import os
+ from threading import Thread
+ import uuid
+ import io

+ # Model and Processor Loading (Done once at startup)
+ MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+ DESCRIPTION = "# **Qwen2.5-VL-3B-Instruct**"
+
+ image_extensions = Image.registered_extensions()
+ video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
+
+
+ def identify_and_save_blob(blob_path):
+     """Identifies if the blob is an image or video and saves it accordingly."""
+     try:
+         with open(blob_path, 'rb') as file:
+             blob_content = file.read()
+
+         # Try to identify if it's an image
+         try:
+             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
+             extension = ".png"  # Default to PNG for saving
+             media_type = "image"
+         except (IOError, SyntaxError):
+             # If it's not a valid image, assume it's a video
+             extension = ".mp4"  # Default to MP4 for saving
+             media_type = "video"
+
+         # Create a unique filename
+         filename = f"temp_{uuid.uuid4()}_media{extension}"
+         with open(filename, "wb") as f:
+             f.write(blob_content)
+
+         return filename, media_type
+
+     except FileNotFoundError:
+         raise ValueError(f"The file {blob_path} was not found.")
+     except Exception as e:
+         raise ValueError(f"An error occurred while processing the file: {e}")
+
+
+ @spaces.GPU
+ def qwen_inference(media_input, text_input=None):
+     if isinstance(media_input, str):  # If it's a filepath
+         media_path = media_input
+         if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
+             media_type = "image"
+         elif media_path.endswith(video_extensions):
+             media_type = "video"
+         else:
+             try:
+                 media_path, media_type = identify_and_save_blob(media_input)
+                 print(media_path, media_type)
+             except Exception as e:
+                 print(e)
+                 raise ValueError(
+                     "Unsupported media type. Please upload an image or video."
+                 )
+
+
+     print(media_path)

+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": media_type,
+                     media_type: media_path,
+                     **({"fps": 8.0} if media_type == "video" else {}),
+                 },
+                 {"type": "text", "text": text_input},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to("cuda")
+
+     streamer = TextIteratorStreamer(
+         processor, skip_prompt=True, **{"skip_special_tokens": True}
      )
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         yield buffer
+
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Tab(label="Image/Video Input"):
+         with gr.Row():
+             with gr.Column():
+                 input_media = gr.File(
+                     label="Upload Image or Video", type="filepath"
+                 )
+                 text_input = gr.Textbox(label="Question")
+                 submit_btn = gr.Button(value="Submit")
+             with gr.Column():
+                 output_text = gr.Textbox(label="Output Text")
+
+         submit_btn.click(
+             qwen_inference, [input_media, text_input], [output_text]
+         )
+
+ demo.launch(debug=True)
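
For context on the new flow: `qwen_inference` is a generator. It builds a Qwen2.5-VL chat message from the uploaded file, starts `model.generate` on a background thread, and yields the accumulated answer each time `TextIteratorStreamer` produces a chunk, which is what lets the output textbox update progressively. The sketch below is a minimal, illustrative way to drive that generator outside the Gradio UI; it is not part of this commit, the file name and prompt are placeholders, and it assumes a CUDA machine with the Space's dependencies installed and the trailing `demo.launch(debug=True)` temporarily commented out so that importing `app` returns.

```python
# Illustrative smoke test only (not part of this commit).
# Assumes: a CUDA device, transformers + qwen_vl_utils + gradio + spaces installed,
# and demo.launch(debug=True) in app.py commented out so the import does not block.
import app

answer = ""
# Each yielded value is the full answer generated so far, so keep only the last one.
for partial in app.qwen_inference("sample.png", "Describe this image in one sentence."):
    answer = partial

print(answer)
```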