import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser


# Step 1: Load the model (cached so Streamlit does not reload it on every rerun)
@st.cache_resource
def load_model():
    st.write("Loading the model...")
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    st.write("Model loaded successfully!")
    return processor, model, device


# Step 2: Upload images or videos
def upload_media():
    return st.file_uploader(
        "Choose images or videos...",
        type=["jpg", "jpeg", "png", "mp4", "avi", "mov"],
        accept_multiple_files=True,
    )


# Step 3: Enter your question
def get_user_question():
    return st.text_input("Ask a question about the images or videos:")


# Display an uploaded image, downscaled to save memory
def process_image(uploaded_file):
    image = Image.open(uploaded_file)
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    return image


# Extract and display the first frame of an uploaded video
def process_video(uploaded_file):
    tfile = tempfile.NamedTemporaryFile(delete=False)
    tfile.write(uploaded_file.read())
    tfile.close()  # Close before reading so OpenCV can open the file on all platforms
    cap = cv2.VideoCapture(tfile.name)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        st.error("Failed to read the video file.")
        return None
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image


# Generate a description of one image with Qwen2-VL
def generate_description(processor, model, device, image, user_question):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_question},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated answer is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]


# Generate a story from the collected descriptions with a local Ollama model
def generate_story(descriptions):
    combined_text = " ".join(descriptions)
    prompt_template = PromptTemplate(
        input_variables=["descriptions"],
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:",
    )
    ollama_llm = Ollama(model="llama3.1")
    output_parser = StrOutputParser()
    # LCEL pipe composition replaces the deprecated LLMChain/chain.run API
    chain = prompt_template | ollama_llm | output_parser
    return chain.invoke({"descriptions": combined_text})


# Main function to control the flow
def main():
    st.title("Media Story Generator")

    # Step 1: Load the model
    processor, model, device = load_model()

    # Step 2: Upload images or videos
    uploaded_files = upload_media()
    if uploaded_files:
        # Step 3: Enter your question
        user_question = get_user_question()
        if user_question:
            # Step 4: Generate descriptions
            st.write("Step 4: Generate description")
            generate_description_button = st.button("Generate Descriptions", key="generate_descriptions")
            if generate_description_button:
                all_output_texts = []
                for idx, uploaded_file in enumerate(uploaded_files):
                    file_type = uploaded_file.type.split('/')[0]
                    image = None
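                    # Dispatch on the MIME type reported by Streamlit,
                    # e.g. "image/png" -> "image", "video/mp4" -> "video"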
                    if file_type == 'image':
                        image = process_image(uploaded_file)
                    elif file_type == 'video':
                        image = process_video(uploaded_file)
                    else:
                        st.error("Unsupported file type.")
                        continue
                    if image:
                        description = generate_description(processor, model, device, image, user_question)
                        st.write(f"Description for file {idx + 1}:")
                        st.write(description)
                        all_output_texts.append(description)
                # Store descriptions in session state so they survive Streamlit reruns
                st.session_state["all_output_texts"] = all_output_texts

    # Offer story generation once descriptions are available in session state
    if "all_output_texts" in st.session_state and st.session_state["all_output_texts"]:
        st.write("Generate story")
        generate_story_button = st.button("Generate Story", key="generate_story")
        if generate_story_button:
            story = generate_story(st.session_state["all_output_texts"])
            st.write("Generated Story:")
            st.write(story)


if __name__ == "__main__":
    main()
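# A minimal way to run this app locally (a sketch, assuming an Ollama server is
# running with the llama3.1 model already pulled, and "app.py" is whatever name
# this file is saved under):
#
#   pip install streamlit transformers torch pillow opencv-python langchain-core langchain-community
#   ollama pull llama3.1
#   streamlit run app.py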