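# Streamlit app: upload images/videos, describe each with Qwen2-VL, then have a
# local Ollama model (via LangChain) turn the combined descriptions into a short story.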
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser

# Load the processor and model directly
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
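# Note: the model is loaded in full precision here; if GPU memory is tight, passing
# torch_dtype=torch.float16 to from_pretrained() is one way to roughly halve the footprint.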
# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Streamlit app
st.title("Media Description Generator")
uploaded_files = st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)
if uploaded_files:
    user_question = st.text_input("Ask a question about the images or videos:")
    if user_question:
        all_output_texts = []  # Initialize an empty list to store all output texts
        for uploaded_file in uploaded_files:
            file_type = uploaded_file.type.split('/')[0]
            if file_type == 'image':
                # Open the image
                image = Image.open(uploaded_file)
                # Resize image to reduce memory usage
                image = image.resize((512, 512))
                st.image(image, caption='Uploaded Image.', use_column_width=True)
                st.write("Generating description...")
            elif file_type == 'video':
                # Save the uploaded video to a temporary file
                tfile = tempfile.NamedTemporaryFile(delete=False)
                tfile.write(uploaded_file.read())
                # Open the video file
                cap = cv2.VideoCapture(tfile.name)
                # Extract the first frame
                ret, frame = cap.read()
                if not ret:
                    st.error("Failed to read the video file.")
                    cap.release()
                    continue
                # Convert the frame to an image
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # Resize image to reduce memory usage
                image = image.resize((512, 512))
                st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
                st.write("Generating description...")
                # Release the video capture object
                cap.release()
            else:
                st.error("Unsupported file type.")
                continue
            # Ensure the image is loaded correctly
            if image is None:
                st.error("Failed to load the image.")
                continue
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                        },
                        {"type": "text", "text": user_question},
                    ],
                }
            ]
            # Preparation for inference
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Pass the image to the processor
            inputs = processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)  # Ensure inputs are on the same device as the model
            # Inference: generation of the output
            try:
                generated_ids = model.generate(**inputs, max_new_tokens=512)
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_text = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                st.write("Description:")
                st.write(output_text[0])
                # Append the output text to the list
                all_output_texts.append(output_text[0])
            except Exception as e:
                st.error(f"Error during generation: {e}")
                continue
            # Clear memory after processing each file
            del image, inputs, generated_ids, generated_ids_trimmed, output_text
            torch.cuda.empty_cache()
            torch.manual_seed(0)  # Reset the seed to ensure reproducibility
        # Combine all descriptions into a single text
        combined_text = " ".join(all_output_texts)
        # Define the prompt template for LangChain
        prompt_template = PromptTemplate(
            input_variables=["descriptions"],
            template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
        )
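        # NOTE: Ollama runs as a separate local service (default http://localhost:11434);
        # the model must already be pulled (`ollama pull llama3.1`) or the call below will fail.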
        # Create the LLMChain with the Ollama model
        ollama_llm = Ollama(model="llama3.1")
        output_parser = StrOutputParser()
        chain = LLMChain(
            llm=ollama_llm,
            prompt=prompt_template,
            output_parser=output_parser
        )
        # Generate the story using LangChain
        story = chain.run({"descriptions": combined_text})
        # Display the generated story
        st.write("Generated Story:")
        st.write(story)
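
# To run locally: `streamlit run app.py` (with an Ollama server already running).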