greencatted committed on
Commit
0bef8ce
1 Parent(s): 4a41b58

Use Llama Vision Instruct

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -1,20 +1,37 @@
1
  import streamlit as st
2
  from PIL import Image
3
 
4
- from transformers import BlipProcessor, BlipForConditionalGeneration
 
5
 
6
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
7
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
 
 
 
 
 
 
8
 
9
  enable = st.checkbox("Enable camera")
10
  picture = st.camera_input("Take a picture", disabled=not enable)
11
 
12
  if picture:
13
- raw_image = Image.open(picture)
14
 
15
- # conditional image captioning
16
- text = "A view of a person in"
17
- inputs = processor(raw_image, text, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
18
 
19
- out = model.generate(**inputs)
20
- st.write(processor.decode(out[0], skip_special_tokens=True))
 
1
  import streamlit as st
2
  from PIL import Image
3
 
4
+ import torch
5
+ from transformers import MllamaForConditionalGeneration, AutoProcessor
6
 
7
+ model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
8
+
9
+ model = MllamaForConditionalGeneration.from_pretrained(
10
+ model_id,
11
+ torch_dtype=torch.bfloat16,
12
+ device_map="auto",
13
+ )
14
+ processor = AutoProcessor.from_pretrained(model_id)
15
 
16
  enable = st.checkbox("Enable camera")
17
  picture = st.camera_input("Take a picture", disabled=not enable)
18
 
19
  if picture:
20
+ image = Image.open(picture)
21
 
22
+ messages = [
23
+ {"role": "user", "content": [
24
+ {"type": "image"},
25
+ {"type": "text", "text": "Provide your best guess as to where this person is holding his online meeting. Just state your guess of location in your response."}
26
+ ]}
27
+ ]
28
+ input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
29
+ inputs = processor(
30
+ image,
31
+ input_text,
32
+ add_special_tokens=False,
33
+ return_tensors="pt"
34
+ ).to(model.device)
35
 
36
+ output = model.generate(**inputs, max_new_tokens=30)
37
+ print(processor.decode(output[0]))