|
import base64
import html
import io
import mimetypes
import os
import tempfile
import zipfile

import pandas as pd
import spacy
import streamlit as st
from nltk.corpus import wordnet
from PIL import Image
from spacy.cli import download
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
|
|
|
|
|
@st.cache_resource
def _load_spacy_pipeline(name="en_core_web_sm"):
    """Load the spaCy pipeline *name*, downloading it only if it is missing.

    Streamlit re-executes this script on every widget interaction; caching with
    ``st.cache_resource`` keeps that from re-running the download command and
    reloading the pipeline each time.
    """
    try:
        return spacy.load(name)
    except OSError:
        # spacy.load raises OSError when the pipeline package is not installed.
        download(name)
        return spacy.load(name)


nlp = _load_spacy_pipeline()
|
|
|
|
|
import nltk


@st.cache_resource
def _ensure_nltk_data(packages=("wordnet", "omw-1.4")):
    """Download the NLTK data *packages* once per server process.

    Cached so the status check / download does not repeat on every Streamlit
    rerun. Returns ``{package: result}`` where each result is the value
    returned by ``nltk.download`` (True on success). Note that the results,
    including a failure, are cached for the life of the process.
    """
    return {package: nltk.download(package, quiet=True) for package in packages}


_ensure_nltk_data()
|
|
|
|
|
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-11"


@st.cache_resource
def _load_captioning_model(name):
    """Load the captioning model, its image processor and its tokenizer.

    The tokenizer's EOS token is reused as the pad token (the decoder is
    presumably GPT-2 style, which has no dedicated pad token), and the
    tokenizer's EOS/BOS/pad ids are copied onto the model config so
    ``generate`` knows how to start, pad and stop. Cached so the weights are
    loaded once per server process rather than on every Streamlit rerun.

    Returns:
        ``(model, image_processor, tokenizer)``.
    """
    caption_model = VisionEncoderDecoderModel.from_pretrained(name)
    image_processor = ViTImageProcessor.from_pretrained(name)
    caption_tokenizer = AutoTokenizer.from_pretrained(name)
    caption_tokenizer.pad_token = caption_tokenizer.eos_token

    caption_model.config.eos_token_id = caption_tokenizer.eos_token_id
    caption_model.config.decoder_start_token_id = caption_tokenizer.bos_token_id
    caption_model.config.pad_token_id = caption_tokenizer.pad_token_id
    return caption_model, image_processor, caption_tokenizer


model, feature_extractor, tokenizer = _load_captioning_model(model_name)
|
|
|
def generate_caption(image, **generate_kwargs):
    """Return a text caption for *image* produced by the captioning model.

    Args:
        image: A PIL image. Images in any mode other than RGB (RGBA PNGs,
            grayscale, palette, ...) are converted to RGB first, because the
            ViT image processor expects 3-channel input.
        **generate_kwargs: Optional overrides forwarded to ``model.generate``
            (e.g. ``max_length=16, num_beams=4``). None are passed by default.

    Returns:
        The decoded caption, with special tokens and surrounding whitespace removed.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values, **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
|
|
def get_synonyms(word):
    """Return the set of WordNet lemma names across every synset of *word*.

    Names are returned exactly as ``Lemma.name()`` gives them (WordNet joins
    multi-word lemmas with underscores and keeps their capitalization).
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
|
|
|
def preprocess_query(query, *, skip_stopwords=True):
    """Expand *query* into a set of lowercase match terms.

    For every content token the set receives the token's lowercased text,
    its lowercased lemma, and its lowercased WordNet synonyms.

    Punctuation and whitespace tokens are always dropped. Stop words ("a",
    "the", "on", ...) are dropped unless *skip_stopwords* is False: they
    appear in almost every caption, so keeping them made any query match
    every image.

    Args:
        query: Free text (a search query or a caption).
        skip_stopwords: Keyword-only; set False to keep stop words.

    Returns:
        A set of strings. It is empty when *query* has no content tokens.
    """
    terms = set()
    for token in nlp(query):
        if token.is_punct or token.is_space:
            continue
        if skip_stopwords and token.is_stop:
            continue
        text = token.text.lower()
        terms.add(text)
        terms.add(token.lemma_.lower())
        terms.update(synonym.lower() for synonym in get_synonyms(text))
    return terms
|
|
|
def search_captions(query, captions):
    """Return the ``(path, caption)`` pairs whose caption shares a term with *query*.

    The query and each caption are expanded with ``preprocess_query``; a
    caption matches when the two term sets overlap. Matches keep the
    iteration order of the *captions* mapping.
    """
    wanted_terms = preprocess_query(query)
    return [
        (path, caption)
        for path, caption in captions.items()
        if not wanted_terms.isdisjoint(preprocess_query(caption))
    ]
|
|
|
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
GRID_COLUMNS = 4


def _is_image_name(name):
    """Return True when *name* has a supported image extension.

    macOS "AppleDouble" entries (``._photo.jpg``, typically under ``__MACOSX/``
    in ZIP archives) carry an image extension but are not images, so they are
    rejected.
    """
    base_name = os.path.basename(name)
    return not base_name.startswith("._") and base_name.lower().endswith(IMAGE_EXTENSIONS)


def _collect_images(input_option):
    """Render the input widgets for *input_option* and return ``{label: source}``.

    *label* is what the user sees: a folder path, an upload's file name or a
    ZIP member name. *source* is either a filesystem path (``str``) or the
    raw encoded image ``bytes``.

    Uploads and ZIP members are kept in memory. They are not written to
    ``delete=False`` temp files, which were recreated on every rerun and never
    removed, nor extracted into a directory shared by every session.
    """
    images = {}
    if input_option == "Folder Path":
        folder_path = st.text_input("Enter the folder path containing images:")
        if folder_path and os.path.isdir(folder_path):
            for file_name in sorted(os.listdir(folder_path)):
                path = os.path.join(folder_path, file_name)
                if _is_image_name(file_name) and os.path.isfile(path):
                    images[path] = path
    elif input_option == "Upload Images":
        uploaded_files = st.file_uploader(
            "Upload image files", type=["png", "jpg", "jpeg"], accept_multiple_files=True
        )
        for uploaded_file in uploaded_files or []:
            images[uploaded_file.name] = uploaded_file.getvalue()
    elif input_option == "Upload ZIP":
        uploaded_zip = st.file_uploader("Upload a ZIP file containing images", type=["zip"])
        if uploaded_zip:
            try:
                # The upload is a seekable in-memory file, so members are read
                # straight from it without a temp copy on disk.
                with zipfile.ZipFile(uploaded_zip) as archive:
                    for member in archive.infolist():
                        if not member.is_dir() and _is_image_name(member.filename):
                            images[member.filename] = archive.read(member)
            except zipfile.BadZipFile as e:
                st.error(f"Invalid ZIP file: {e}")
    return images


def _read_image_bytes(source):
    """Return the encoded image bytes for *source* (a path or bytes)."""
    if isinstance(source, bytes):
        return source
    with open(source, "rb") as image_file:
        return image_file.read()


def _open_rgb_image(source):
    """Decode *source* into an RGB PIL image without leaving a file handle open."""
    with Image.open(io.BytesIO(_read_image_bytes(source))) as image:
        # convert() returns a new, fully loaded image, so it outlives the `with`.
        return image.convert("RGB")


def _generate_captions(images):
    """Caption every entry of *images* and return ``{label: caption}``.

    Images that fail to decode or caption are reported with ``st.error`` and
    left out of the result, so one bad file does not abort the whole batch.
    """
    captions = {}
    for label, source in images.items():
        try:
            captions[label] = generate_caption(_open_rgb_image(source))
        except Exception as e:
            st.error(f"Error processing {label}: {e}")
    return captions


def _render_gallery(items, sources, error_prefix):
    """Show ``(label, caption)`` *items* as a grid GRID_COLUMNS images wide.

    *sources* maps each label to its path or bytes. Per-image failures are
    reported as ``"<error_prefix> <label>: <error>"``.
    """
    cols = st.columns(GRID_COLUMNS)
    for idx, (label, caption) in enumerate(items):
        with cols[idx % GRID_COLUMNS]:
            try:
                encoded_image = base64.b64encode(_read_image_bytes(sources[label])).decode()
                mime_type = mimetypes.guess_type(label)[0] or "image/jpeg"
                # Captions and file/member names are untrusted text rendered with
                # unsafe_allow_html, so they must be HTML-escaped. The markup is
                # kept unindented: Markdown would treat indented lines as code.
                html_block = (
                    "<div style='text-align: center;'>"
                    f"<img src='data:{mime_type};base64,{encoded_image}' width='100%'>"
                    f"<p>{html.escape(caption)}</p>"
                    f"<p style='font-size: small; font-style: italic;'>{html.escape(label)}</p>"
                    "</div>"
                )
                st.markdown(html_block, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"{error_prefix} {label}: {e}")


st.title("Image Captioning Gallery")

with st.sidebar:
    query = st.text_input("Search images by caption:")

input_option = st.selectbox("Select input method:", ["Folder Path", "Upload Images", "Upload ZIP"])
images = _collect_images(input_option)

# Streamlit reruns this whole script on every interaction (typing a query,
# clicking download, ...). Captions therefore live in session state; a plain
# local dict was emptied on the next rerun, so search and export saw nothing.
if st.button("Generate Captions", key='generate_captions'):
    if not images:
        st.warning("No images found for the selected input.")
    with st.spinner("Generating captions..."):
        new_captions = _generate_captions(images)
    st.session_state["captions"] = new_captions
    st.session_state["image_sources"] = {label: images[label] for label in new_captions}

captions = st.session_state.get("captions", {})
image_sources = st.session_state.get("image_sources", {})

if captions:
    st.subheader("Images and Captions:")
    _render_gallery(captions.items(), image_sources, "Error displaying")

    if query:
        results = search_captions(query, captions)
        st.write("Search Results:")
        if results:
            _render_gallery(results, image_sources, "Error displaying search result")
        else:
            st.info("No images match the search.")

    df = pd.DataFrame(list(captions.items()), columns=['Image', 'Caption'])
    excel_file = io.BytesIO()
    df.to_excel(excel_file, index=False)
    st.download_button(
        label="Download captions as Excel",
        data=excel_file.getvalue(),
        file_name="captions.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
|
|