|
from transformers import Blip2ForConditionalGeneration |
|
from transformers import Blip2Processor |
|
from peft import PeftModel |
|
import streamlit as st |
|
from PIL import Image |
|
|
|
import os |
|
|
|
# Checkpoint id passed to Blip2Processor.from_pretrained (processor only).
preprocess_ckp = "Salesforce/blip2-opt-2.7b"
# Local directory of the base BLIP-2 model weights.
base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded"
# Local directory of the PEFT adapter loaded on top of the base model.
peft_model_ckp = "./model/blip2_peft"
# Directory (relative to the current working directory) holding the sample images.
sample_img_path = "./sample_images"

# Sample label shown in the UI -> image file name inside sample_img_path.
# Insertion order matters: it is the order the samples are offered in.
map_sampleid_name = dict(
    dress="00fe223d-9d1f-4bd3-a556-7ece9d28e6fb.jpeg",
    earrings="0b3862ae-f89e-419c-bc1e-57418abd4180.jpeg",
    sweater="0c21ba7b-ceb6-4136-94a4-1d4394499986.jpeg",
    sunglasses="0e44ec10-e53b-473a-a77f-ac8828bb5e01.jpeg",
    shoe="4cd37d6d-e7ea-4c6e-aab2-af700e480bc1.jpeg",
    hat="69aeb517-c66c-47b8-af7d-bdf1fde57ed0.jpeg",
    heels="447abc42-6ac7-4458-a514-bdcd570b1cd1.jpeg",
    socks="d188836c-b734-4031-98e5-423d5ff1239d.jpeg",
    tee="e2d8637a-5478-429d-a2a8-3d5859dbc64d.jpeg",
    bracelet="e78518ac-0f54-4483-a233-fad6511f0b86.jpeg",
)
|
|
|
|
|
@st.cache_resource
def init_model():
    """Load the BLIP-2 processor and the PEFT-adapted captioning model.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the large BLIP-2 (OPT-2.7b) base model and
    its adapter would be reloaded from disk on every rerun. With it, they are
    loaded once and the same objects are returned on later calls.

    Returns:
        tuple: ``(processor, model)``.
            ``processor`` is the ``Blip2Processor`` loaded from
            ``preprocess_ckp``.
            ``model`` is a ``PeftModel`` that wraps the base
            ``Blip2ForConditionalGeneration`` loaded from ``base_model_ckp``
            with the adapter weights from ``peft_model_ckp``.
    """
    processor = Blip2Processor.from_pretrained(preprocess_ckp)

    base_model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
    model = PeftModel.from_pretrained(base_model, peft_model_ckp)

    return processor, model
|
|
|
def _open_image(source):
    """Open *source* (a path or file-like object) and decode it eagerly.

    The pixel data is loaded inside a ``with`` block. This means a file Pillow
    opened from a path is closed right away instead of lingering until
    garbage collection.

    Returns the PIL image. Returns None after showing an error if the file is
    missing or cannot be decoded. ``PIL.UnidentifiedImageError`` is a
    subclass of ``OSError``, so catching ``OSError`` covers it.
    """
    try:
        with Image.open(source) as img:
            img.load()
    except OSError as err:
        st.error(f"Could not open image: {err}")
        return None
    return img


def _select_image():
    """Render the image pickers and return the chosen PIL image, or None.

    An uploaded file takes precedence over the sample selectbox.
    """
    st.text("Select image:")
    # 'None' sentinel followed by every sample label, in map_sampleid_name order.
    option = st.selectbox('From sample', ('None', *map_sampleid_name), index=0)
    st.text("OR")
    uploaded = st.file_uploader("Upload an image")

    if uploaded is not None:
        return _open_image(uploaded)
    # `!=`, not `is not`: identity comparison against a str literal depends on
    # string interning and raises a SyntaxWarning on Python 3.8+.
    if option != 'None':
        return _open_image(os.path.join(sample_img_path, map_sampleid_name[option]))
    return None


def _generate_caption(processor, model, image):
    """Return the caption *model* generates for *image* (at most 25 tokens)."""
    inputs = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=25)
    # Decoded text may carry surrounding whitespace; the Hugging Face BLIP-2
    # usage example strips it as well.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


def main():
    """Streamlit entry point: pick or upload an image and show its caption."""
    st.title("Fashion Image Caption using BLIP2")

    processor, model = init_model()

    image = _select_image()
    if image is None:
        return

    image_col, caption_col = st.columns(2)
    image_col.header("Image")
    image_col.image(image, use_column_width=True)

    caption = _generate_caption(processor, model, image)
    caption_col.header("Generated Caption")
    caption_col.text(caption)


if __name__ == "__main__":
    main()