VinitT's picture
Update app.py
8fc60be verified
raw
history blame
5.15 kB
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain import LLMChain, PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
# Step 1: Load the model
@st.cache_resource
def load_model():
    """Load the Qwen2-VL-2B-Instruct processor and model, caching them across reruns.

    Streamlit re-executes the whole script on every widget interaction (e.g.
    each button click). Without caching, every interaction would reload the
    model weights from disk. ``st.cache_resource`` keeps the returned objects
    alive for subsequent reruns.

    Returns:
        tuple: ``(processor, model, device)``. ``device`` is CUDA when
        ``torch.cuda.is_available()``, otherwise CPU, and ``model`` has already
        been moved onto it.
    """
    st.write("Loading the model...")
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    st.write("Model loaded successfully!")
    return processor, model, device
# Step 2: Upload image or video
def upload_media():
    """Render the multi-file uploader restricted to common image/video formats.

    Returns:
        Whatever ``st.file_uploader`` yields with ``accept_multiple_files=True``;
        callers treat a falsy value as "nothing uploaded yet".
    """
    allowed_extensions = ["jpg", "jpeg", "png", "mp4", "avi", "mov"]
    uploader_label = "Choose images or videos..."
    return st.file_uploader(
        uploader_label,
        type=allowed_extensions,
        accept_multiple_files=True,
    )
# Step 3: Enter your question
def get_user_question():
    """Show a text box for the user's question and return its current contents.

    Returns:
        The text entered in the input widget.
    """
    question_label = "Ask a question about the images or videos:"
    question = st.text_input(question_label)
    return question
# Process image
def process_image(uploaded_file):
    """Open an uploaded image, shrink it to 512x512 and preview it in the app.

    Args:
        uploaded_file: File-like object accepted by ``PIL.Image.open``.

    Returns:
        The resized ``PIL.Image.Image`` that was displayed.
    """
    # A fixed 512x512 target bounds memory use; the aspect ratio is not preserved.
    resized = Image.open(uploaded_file).resize((512, 512))
    st.image(resized, caption='Uploaded Image.', use_column_width=True)
    return resized
# Process video
def process_video(uploaded_file):
    """Grab the first frame of an uploaded video, shrink it and preview it in the app.

    ``cv2.VideoCapture`` is opened by file path, so the upload is copied to a
    temporary file first. That file is closed before OpenCV reads it and is
    removed again before returning.

    Args:
        uploaded_file: File-like object with ``read()``. If it has a ``name``,
            that name's extension is kept on the temporary file, which may help
            the video backend recognize the container.

    Returns:
        The first frame as a 256x256 RGB ``PIL.Image.Image``, or ``None`` if no
        frame could be read (an error message is shown in the app).
    """
    import contextlib
    import os

    original_name = getattr(uploaded_file, "name", None) or ""
    suffix = os.path.splitext(original_name)[1]

    # delete=False so the path survives close(); we remove it ourselves below.
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Close (and thereby flush) before OpenCV opens the path. Unflushed bytes
        # would be invisible to the reader, and on some platforms (Windows) a
        # still-open NamedTemporaryFile cannot be opened a second time.
        with tfile:
            tfile.write(uploaded_file.read())
        cap = cv2.VideoCapture(tfile.name)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
    finally:
        # Best-effort cleanup: a failed delete must not crash the app.
        with contextlib.suppress(OSError):
            os.remove(tfile.name)

    if not ret:
        st.error("Failed to read the video file.")
        return None
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image
# Generate description
def generate_description(processor, model, device, image, user_question):
    """Ask the vision-language model ``user_question`` about ``image``.

    Args:
        processor: Processor used to build the chat prompt, tensorize the
            inputs and decode the output tokens.
        model: Generation model; its ``generate`` output is decoded here.
        device: Device the tokenized inputs are moved to (should match the model's).
        image: Image passed both in the chat message and to the processor.
        user_question: Text prompt from the user.

    Returns:
        The decoded answer for the single prompt, without the prompt tokens.
    """
    content = [
        {"type": "image", "image": image},
        {"type": "text", "text": user_question},
    ]
    conversation = [{"role": "user", "content": content}]
    prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

    batch = processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
    batch = batch.to(device)
    output_ids = model.generate(**batch, max_new_tokens=512)

    # generate() returns prompt + continuation; keep only the newly produced tokens.
    new_tokens = []
    for prompt_ids, full_ids in zip(batch.input_ids, output_ids):
        new_tokens.append(full_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        new_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
# Generate story
def generate_story(descriptions):
    """Turn media descriptions into a short story using a local Ollama model.

    Args:
        descriptions: Iterable of description strings; they are joined with
            single spaces and inserted into the prompt.

    Returns:
        The story text generated by the ``llama3.1`` model via ``Ollama``.
    """
    combined_text = " ".join(descriptions)
    prompt_template = PromptTemplate(
        input_variables=["descriptions"],
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
    )
    ollama_llm = Ollama(model="llama3.1")
    # LCEL pipeline instead of the deprecated LLMChain / Chain.run API.
    # StrOutputParser makes invoke() return a plain string, as run() did.
    chain = prompt_template | ollama_llm | StrOutputParser()
    return chain.invoke({"descriptions": combined_text})
# Main function to control the flow
def main():
    """Drive the app: load the model, collect media and a question, describe, then tell a story.

    Streamlit re-executes this script on every interaction, and ``st.button``
    is True only in the rerun triggered by its own click. The descriptions are
    therefore stored in ``st.session_state``. Otherwise, clicking
    "Generate Story" would start a rerun in which no descriptions exist, and
    the story could never be produced.
    """
    st.title("Media Story Generator")
    # Step 1: Load the model
    processor, model, device = load_model()
    # Step 2: Upload image or video
    uploaded_files = upload_media()
    if not uploaded_files:
        return
    # Step 3: Enter your question
    user_question = get_user_question()
    if not user_question:
        return

    # Discard descriptions produced for a different question or file set so the
    # story is never built from stale results.
    signature = (user_question, tuple(f.name for f in uploaded_files))
    if st.session_state.get("descriptions_signature") != signature:
        st.session_state["descriptions_signature"] = signature
        st.session_state["descriptions"] = []

    # Step 4: Generate description
    st.write("Step 4: Generate description")
    if st.button("Generate Description"):
        all_output_texts = []
        for uploaded_file in uploaded_files:
            file_type = uploaded_file.type.split('/')[0]
            if file_type == 'image':
                image = process_image(uploaded_file)
            elif file_type == 'video':
                image = process_video(uploaded_file)
            else:
                st.error("Unsupported file type.")
                continue
            # process_video returns None when no frame could be read.
            if image is not None:
                # Reset the RNG before every generation so each file starts from
                # the same state (relevant if the generation config samples).
                torch.manual_seed(0)
                description = generate_description(processor, model, device, image, user_question)
                st.write("Description:")
                st.write(description)
                all_output_texts.append(description)
            # Clear memory after processing each file
            del image
            torch.cuda.empty_cache()
        st.session_state["descriptions"] = all_output_texts

    all_output_texts = st.session_state["descriptions"]
    if all_output_texts:
        # Step 5: Generate story
        if st.button("Generate Story"):
            story = generate_story(all_output_texts)
            st.write("Generated Story:")
            st.write(story)


if __name__ == "__main__":
    main()