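# NOTE: the line below is page residue from the Hugging Face Spaces view this
# file was copied from; it is not part of the program.
# Spaces status: Runtime error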
import io
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForCausalLM
# Function to perform mean pooling on the model outputs
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions marked in the mask.

    Args:
        model_output: Either a mapping-style output exposing the token
            embeddings under ``'last_hidden_state'``, or a tuple/list whose
            first element holds them.
        attention_mask: Tensor with one entry per token; entries equal to 0
            are excluded from the average.

    Returns:
        A tensor with the token axis (dim 1) reduced away: the mask-weighted
        mean of the token embeddings.
    """
    if isinstance(model_output, (tuple, list)):
        # Tuple-style outputs carry the token embeddings at index 0.
        token_embeddings = model_output[0]
    else:
        token_embeddings = model_output['last_hidden_state']

    # Broadcast the mask across the embedding axis so it can weight each value.
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp so a fully masked row divides by a tiny number instead of zero.
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
# Streamlit re-executes this script from the top on every widget interaction.
# The models are therefore built inside cached factories so each one is loaded
# once per server process instead of on every rerun.
# NOTE: st.cache_resource requires Streamlit >= 1.18.


@st.cache_resource
def _load_image_captioner():
    """Build the image-to-text pipeline used to caption images."""
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")


@st.cache_resource
def _load_text_encoder():
    """Load the tokenizer and model used for text processing."""
    checkpoint = 'jim33282007/5240_grp27_proj'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    return tokenizer, model


@st.cache_resource
def _load_text_generator():
    """Load the causal language model and tokenizer used for text generation."""
    checkpoint = 'gpt2-xl'
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# Initialize the pipeline for image-to-text
image_to_text = _load_image_captioner()
# Initialize tokenizer and model for text processing
tokenizer_text, model_text = _load_text_encoder()
# Initialize a text generation model
model_gpt2, tokenizer_gpt2 = _load_text_generator()

st.title('Image Captioning, Text Embedding, Text Generation, and Input Application')
# Function to load images from URL
def load_image_from_url(url, timeout=10):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled host would block the request indefinitely.

    Returns:
        The opened PIL image, or ``None`` when the download or the decoding
        fails. Failures are reported to the user through ``st.error``.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Report 4xx/5xx responses as an HTTP error rather than handing an
        # error page to PIL and getting a misleading decode failure.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any network or
        # decoding problem should become a message rather than a crash.
        st.error(f"Error loading image from URL: {e}")
        return None
# User option to select input type: Upload, URL, or Type Sentence
input_type = st.radio("Select input type:", ("Upload Image", "Image URL", "Type Sentence"))

# At most one of these is filled in, depending on the selected input type.
image = None
typed_text = ""

if input_type == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        except OSError as e:
            # A file with an accepted extension can still be corrupt or not an
            # image at all; report it instead of aborting the script run.
            st.error(f"Error loading uploaded image: {e}")
        else:
            st.image(image, caption='Uploaded Image', use_column_width=True)
elif input_type == "Image URL":
    image_url = st.text_input("Enter the image URL here:", "")
    if image_url:
        image = load_image_from_url(image_url)
        if image is not None:
            st.image(image, caption='Image from URL', use_column_width=True)
elif input_type == "Type Sentence":
    # Whitespace-only input is treated as "nothing typed".
    typed_text = st.text_area("Type your sentence here:").strip()

# Generate caption and process text button
if st.button('Generate Caption and Process Text'):
    if image is not None or typed_text:
        with st.spinner("Processing..."):
            if image is not None:
                # Caption the already decoded image for both image sources, so
                # a URL is not downloaded a second time by the pipeline.
                result = image_to_text(image)
                generated_text_p1 = result[0]['generated_text']
            else:
                generated_text_p1 = typed_text

            if generated_text_p1:
                st.success(f'Processed Text: {generated_text_p1}')
                # Generate additional text using GPT-2 based on the processed text
                inputs = tokenizer_gpt2(generated_text_p1, return_tensors='pt')
                generated_outputs = model_gpt2.generate(
                    **inputs,  # input_ids together with the attention mask
                    # Budget counts only newly generated tokens, so a long
                    # prompt cannot use up the whole length limit by itself.
                    max_new_tokens=100,
                    num_return_sequences=1,
                    # Explicit pad id (end-of-sequence token reused) so that
                    # generate() does not have to fall back to a default.
                    pad_token_id=tokenizer_gpt2.eos_token_id,
                )
                generated_text = tokenizer_gpt2.decode(generated_outputs[0], skip_special_tokens=True)
                st.text_area("Generated Text:", generated_text, height=200)
            else:
                st.warning("The captioning model returned an empty caption.")
    else:
        st.error("Please upload an image, enter an image URL, or type a sentence first.")