import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser

# Load the processor and model directly
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Streamlit app
st.title("Media Description Generator")

uploaded_files = st.file_uploader(
    "Choose images or videos...",
    type=["jpg", "jpeg", "png", "mp4", "avi", "mov"],
    accept_multiple_files=True,
)

if uploaded_files:
    user_question = st.text_input("Ask a question about the images or videos:")

    if user_question:
        all_output_texts = []  # Collect the generated description for each file

        for uploaded_file in uploaded_files:
            file_type = uploaded_file.type.split('/')[0]

            if file_type == 'image':
                # Open the image
                image = Image.open(uploaded_file)
                # Resize image to reduce memory usage
                image = image.resize((512, 512))
                st.image(image, caption='Uploaded Image.', use_column_width=True)
                st.write("Generating description...")
            elif file_type == 'video':
                # Save the uploaded video to a temporary file
                tfile = tempfile.NamedTemporaryFile(delete=False)
                tfile.write(uploaded_file.read())
                tfile.close()  # Flush to disk so OpenCV can read the complete file

                # Open the video file
                cap = cv2.VideoCapture(tfile.name)

                # Extract the first frame
                ret, frame = cap.read()
                if not ret:
                    st.error("Failed to read the video file.")
                    cap.release()
                    continue
                else:
                    # Convert the frame to an image
                    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    # Resize image to reduce memory usage
                    image = image.resize((512, 512))
                    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
                    st.write("Generating description...")

                # Release the video capture object
                cap.release()
            else:
                st.error("Unsupported file type.")
                continue

            # Ensure the image is loaded correctly
            if image is None:
                st.error("Failed to load the image.")
                continue

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                        },
                        {"type": "text", "text": user_question},
                    ],
                }
            ]

            # Preparation for inference
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Pass the image to the processor
            inputs = processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)  # Ensure inputs are on the same device as the model

            # Inference: generation of the output
            try:
                generated_ids = model.generate(**inputs, max_new_tokens=512)
                # Strip the prompt tokens so only the newly generated tokens are decoded
                generated_ids_trimmed = [
                    out_ids[len(in_ids):]
                    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_text = processor.batch_decode(
                    generated_ids_trimmed,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                st.write("Description:")
                st.write(output_text[0])

                # Append the output text to the list
                all_output_texts.append(output_text[0])
            except Exception as e:
                st.error(f"Error during generation: {e}")
                continue

            # Clear memory after processing each file
            del image, inputs, generated_ids, generated_ids_trimmed, output_text
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            torch.manual_seed(0)  # Reset the seed to ensure reproducibility

        # Combine all descriptions into a single text
        combined_text = " ".join(all_output_texts)

        # Define the prompt template for LangChain
        prompt_template = PromptTemplate(
            input_variables=["descriptions"],
            template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:",
        )

        # Create the LLMChain with the Ollama model
        ollama_llm = Ollama(model="llama3.1")
        output_parser = StrOutputParser()
        chain = LLMChain(
            llm=ollama_llm,
            prompt=prompt_template,
            output_parser=output_parser,
        )

        # Generate the story using LangChain
        story = chain.run({"descriptions": combined_text})

        # Display the generated story
        st.write("Generated Story:")
        st.write(story)
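
# ---------------------------------------------------------------------------
# Usage note: a minimal sketch of how this app would typically be launched.
# It assumes the script is saved as app.py (hypothetical filename), that an
# Ollama server is running locally with the "llama3.1" model already pulled,
# and that the usual PyPI package names below apply to your environment.
#
#   pip install streamlit transformers torch opencv-python pillow \
#       langchain langchain-community
#   ollama pull llama3.1
#   streamlit run app.py
# ---------------------------------------------------------------------------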