sagar007 committed
Commit 8e31ab1 · verified · Parent: ecabb86

Create app.py

Files changed (1):
  app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoConfig, AutoModel
+ from PIL import Image
+ import logging
+ from transformers import BitsAndBytesConfig
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+
+ class LLaVAPhiModel:
+     def __init__(self, model_id="sagar007/Lava_phi"):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logging.info(f"Using device: {self.device}")
+
+         # Initialize quantization config
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4"
+         )
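+         # NF4 4-bit weights with double quantization cut the language model's
+         # weight memory to roughly a quarter of its fp16 footprint.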
+
+         try:
+             # Load model directly from Hugging Face Hub
+             logging.info(f"Loading model from {model_id}...")
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 quantization_config=quantization_config,
+                 device_map="auto",
+                 torch_dtype=torch.bfloat16,
+                 trust_remote_code=True
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+             # Set up padding token
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+                 self.model.config.pad_token_id = self.tokenizer.eos_token_id
+
+             # Load CLIP model and processor
+             logging.info("Loading CLIP model and processor...")
+             self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+             self.clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
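+             # CLIP produces the image embeddings that generate_response() passes
+             # to the language model alongside the text prompt.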
+
+             # Store conversation history
+             self.history = []
+
+         except Exception as e:
+             logging.error(f"Error initializing model: {str(e)}")
+             raise
+
+     def process_image(self, image):
+         """Process image through CLIP"""
+         with torch.no_grad():
+             image_inputs = self.processor(images=image, return_tensors="pt")
+             image_features = self.clip.get_image_features(
+                 pixel_values=image_inputs.pixel_values.to(self.device)
+             )
+             return image_features
+
+     def generate_response(self, message, image=None):
+         try:
+             if image is not None:
+                 # Get image features
+                 image_features = self.process_image(image)
+
+                 # Format prompt
+                 prompt = f"human: <image>\n{message}\ngpt:"
+
+                 # Add context from history
+                 context = ""
+                 for turn in self.history[-3:]:
+                     context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
+
+                 full_prompt = context + prompt
+
+                 # Prepare text inputs
+                 inputs = self.tokenizer(
+                     full_prompt,
+                     return_tensors="pt",
+                     padding=True,
+                     truncation=True,
+                     max_length=512
+                 )
+                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                 # Add image features to inputs
+                 inputs["image_features"] = image_features
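+                 # The checkpoint's custom code (loaded via trust_remote_code) is
+                 # assumed to accept `image_features` as an extra generate() kwarg.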
+
+                 # Generate response
+                 with torch.no_grad():
+                     outputs = self.model.generate(
+                         **inputs,
+                         max_new_tokens=256,
+                         min_length=20,
+                         temperature=0.7,
+                         do_sample=True,
+                         top_p=0.9,
+                         top_k=40,
+                         repetition_penalty=1.5,
+                         no_repeat_ngram_size=3,
+                         use_cache=True,
+                         pad_token_id=self.tokenizer.pad_token_id,
+                         eos_token_id=self.tokenizer.eos_token_id
+                     )
+             else:
+                 # Text-only response
+                 prompt = f"human: {message}\ngpt:"
+                 context = ""
+                 for turn in self.history[-3:]:
+                     context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
+
+                 full_prompt = context + prompt
+                 inputs = self.tokenizer(
+                     full_prompt,
+                     return_tensors="pt",
+                     padding=True,
+                     truncation=True,
+                     max_length=512
+                 )
+                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                 with torch.no_grad():
+                     outputs = self.model.generate(
+                         **inputs,
+                         max_new_tokens=150,
+                         min_length=20,
+                         temperature=0.6,
+                         do_sample=True,
+                         top_p=0.85,
+                         top_k=30,
+                         repetition_penalty=1.8,
+                         no_repeat_ngram_size=4,
+                         use_cache=True,
+                         pad_token_id=self.tokenizer.pad_token_id,
+                         eos_token_id=self.tokenizer.eos_token_id
+                     )
+
+             # Decode response
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Clean up response
+             if "gpt:" in response:
+                 response = response.split("gpt:")[-1].strip()
+             if "human:" in response:
+                 response = response.split("human:")[0].strip()
+             if "<image>" in response:
+                 response = response.replace("<image>", "").strip()
+
+             # Update history
+             self.history.append((message, response))
+
+             return response
+
+         except Exception as e:
+             logging.error(f"Error generating response: {str(e)}")
+             logging.error("Full traceback:", exc_info=True)
+             return f"Error: {str(e)}"
+
+     def clear_history(self):
+         self.history = []
+         return None
+
+ def create_demo():
+     # Initialize model
+     model = LLaVAPhiModel()
+
+     with gr.Blocks(css="footer {visibility: hidden}") as demo:
+         gr.Markdown(
+             """
+             # LLaVA-Phi Demo
+             Chat with a vision-language model that can understand both text and images.
+             """
+         )
+
+         chatbot = gr.Chatbot(height=400)
+         with gr.Row():
+             with gr.Column(scale=0.7):
+                 msg = gr.Textbox(
+                     show_label=False,
+                     placeholder="Enter text and/or upload an image",
+                     container=False
+                 )
+             with gr.Column(scale=0.15, min_width=0):
+                 clear = gr.Button("Clear")
+             with gr.Column(scale=0.15, min_width=0):
+                 submit = gr.Button("Submit", variant="primary")
+
+         image = gr.Image(type="pil", label="Upload Image (Optional)")
+
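+         # Callbacks: respond() sends the textbox contents, chat history and optional
+         # image to the model, then returns a cleared textbox and the updated history.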
+         def respond(message, chat_history, image):
+             if not message and image is None:
+                 return "", chat_history
+
+             response = model.generate_response(message, image)
+             chat_history.append((message, response))
+             return "", chat_history
+
+         def clear_chat():
+             model.clear_history()
+             return None, None
+
+         submit.click(
+             respond,
+             [msg, chatbot, image],
+             [msg, chatbot],
+         )
+
+         clear.click(
+             clear_chat,
+             None,
+             [chatbot, image],
+         )
+
+         msg.submit(
+             respond,
+             [msg, chatbot, image],
+             [msg, chatbot],
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_demo()
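+     # Serve on all network interfaces on port 7860; share=True additionally
+     # requests a temporary public Gradio link.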
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True
+     )