sagar007 committed
Commit ab8bcac · verified · 1 Parent(s): 5904b1d

Update app.py

Files changed (1)
  1. app.py +6 -8
app.py CHANGED
@@ -14,23 +14,22 @@ TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
 VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 
 # Load models and tokenizers
 text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
 text_model = AutoModelForCausalLM.from_pretrained(
     TEXT_MODEL_ID,
-    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
-    device_map="auto" if device == "cuda" else None,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+    device_map="auto",
     low_cpu_mem_usage=True
 )
 
-if device == "cuda":
-    text_model = text_model.half()  # Convert to half precision if on GPU
-
 vision_model = AutoModelForCausalLM.from_pretrained(
     VISION_MODEL_ID,
     trust_remote_code=True,
-    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+    attn_implementation="flash_attention_2" if device == "cuda" else None,
     low_cpu_mem_usage=True
 ).to(device).eval()
 
@@ -46,7 +45,7 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
     ])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
+    input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
@@ -116,5 +115,4 @@ with gr.Blocks() as demo:
     vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])
 
 if __name__ == "__main__":
-    print(f"Running on device: {device}")
     demo.launch()
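
The new attn_implementation argument assumes the flash-attn package is importable whenever CUDA is present; if it is not installed, from_pretrained raises at load time. A minimal guard, sketched here as an assumption rather than something this commit adds (pick_attn_implementation is a hypothetical helper):

def pick_attn_implementation(device: str):
    # Hypothetical helper (not part of this commit): keep the commit's CUDA
    # conditional, but fall back to the transformers default backend when the
    # flash-attn package is missing, since from_pretrained raises otherwise.
    if device != "cuda":
        return None  # None lets transformers auto-select an implementation
    try:
        import flash_attn  # noqa: F401 -- provided by the flash-attn pip package
        return "flash_attention_2"
    except ImportError:
        return None

It would slot in as attn_implementation=pick_attn_implementation(device) in the vision_model call above.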
 
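The second hunk cuts off at generate_kwargs = dict(, so the rest of stream_text_chat is not shown. The usual TextIteratorStreamer pattern that line feeds into runs generate() on a background thread and yields the accumulating text; here is a minimal sketch under that assumption (stream_reply and the max_new_tokens default are illustrative, not read from app.py):

from threading import Thread

from transformers import TextIteratorStreamer

def stream_reply(conversation, max_new_tokens=1024):
    # Hypothetical standalone version of the pattern; text_model, text_tokenizer
    # and device are the globals loaded in the first hunk above.
    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    streamer = TextIteratorStreamer(
        text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    # generate() blocks, so it runs on a worker thread while the streamer is read.
    Thread(
        target=text_model.generate,
        kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens),
    ).start()
    output = ""
    for chunk in streamer:  # yields decoded text incrementally as tokens arrive
        output += chunk
        yield output        # a Gradio chat callback streams the growing reply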