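"""Streamlit app: describe uploaded images and videos with Qwen2-VL, then turn the
collected descriptions into a short story with a locally served Ollama LLM (llama3.1)."""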
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser

# Step 1: Load the model
@st.cache_resource  # cache so Streamlit reruns don't reload the 2B checkpoint on every interaction
def load_model():
    st.write("Loading the model...")
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    st.write("Model loaded successfully!")
    return processor, model, device

# Step 2: Upload image or video
def upload_media():
    return st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)

# Step 3: Enter your question
def get_user_question():
    return st.text_input("Ask a question about the images or videos:")

# Process image
def process_image(uploaded_file):
    image = Image.open(uploaded_file)
    image = image.resize((512, 512))  # Reduce size to save memory
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    return image

# Process video
def process_video(uploaded_file):
    # cv2.VideoCapture needs a real file path, so spill the uploaded bytes to a temp file
    tfile = tempfile.NamedTemporaryFile(delete=False)
    tfile.write(uploaded_file.read())
    tfile.close()  # close the handle so OpenCV can open the file by name on every platform
    cap = cv2.VideoCapture(tfile.name)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        st.error("Failed to read the video file.")
        return None
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image

# Generate description
def generate_description(processor, model, device, image, user_question):
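    # Build a Qwen2-VL chat message that pairs the image with the user's question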
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_question},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=512)
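    # Drop the prompt tokens so only the newly generated answer is decoded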
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return output_text[0]

# Generate story
def generate_story(descriptions):
    combined_text = " ".join(descriptions)
    prompt_template = PromptTemplate(
        input_variables=["descriptions"],
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
    )
    ollama_llm = Ollama(model="llama3.1")
    # Compose prompt -> local Ollama LLM -> string parser with the runnable (LCEL) syntax
    chain = prompt_template | ollama_llm | StrOutputParser()
    return chain.invoke({"descriptions": combined_text})

# Main function to control the flow
def main():
    st.title("Media Story Generator")

    # Keep descriptions in session state so they survive the rerun that
    # Streamlit triggers when the "Generate Story" button is clicked
    if "descriptions" not in st.session_state:
        st.session_state.descriptions = []

    # Step 1: Load the model
    processor, model, device = load_model()

    # Step 2: Upload image or video
    uploaded_files = upload_media()

    if uploaded_files:
        # Step 3: Enter your question
        user_question = get_user_question()

        if user_question:
            # Step 4: Generate description
            st.write("Step 4: Generate description")
            generate_description_button = st.button("Generate Description")

            if generate_description_button:
                st.session_state.descriptions = []

                for uploaded_file in uploaded_files:
                    file_type = uploaded_file.type.split('/')[0]
                    image = None

                    if file_type == 'image':
                        image = process_image(uploaded_file)
                    elif file_type == 'video':
                        image = process_video(uploaded_file)
                    else:
                        st.error("Unsupported file type.")
                        continue

                    if image:
                        description = generate_description(processor, model, device, image, user_question)
                        st.write("Description:")
                        st.write(description)
                        st.session_state.descriptions.append(description)

                    # Clear memory after processing each file
                    del image
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    torch.manual_seed(0)  # reset the RNG so every file is generated from the same seed

            # Step 5: Generate story (kept outside the description branch so the
            # button still works after the rerun it triggers)
            if st.session_state.descriptions:
                generate_story_button = st.button("Generate Story")

                if generate_story_button:
                    story = generate_story(st.session_state.descriptions)
                    st.write("Generated Story:")
                    st.write(story)

if __name__ == "__main__":
    main()