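"""Streamlit app: describe uploaded images/videos with Qwen2-VL, then weave the
descriptions into a short story with a local Ollama LLM (llama3.1)."""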
import tempfile

import cv2
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# Step 1: Load the model
@st.cache_resource
def load_model():
    """Load the Qwen2-VL processor and model once and cache them across Streamlit reruns."""
    st.write("Loading the model...")
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    st.write("Model loaded successfully!")
    return processor, model, device

# Step 2: Upload image or video
def upload_media():
    return st.file_uploader(
        "Choose images or videos...",
        type=["jpg", "jpeg", "png", "mp4", "avi", "mov"],
        accept_multiple_files=True,
    )


# Step 3: Enter your question
def get_user_question():
    return st.text_input("Ask a question about the images or videos:")

# Process image
def process_image(uploaded_file):
    image = Image.open(uploaded_file).convert("RGB")  # Normalize mode (e.g. RGBA/palette) for the processor
    image = image.resize((512, 512))  # Reduce size to save memory
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    return image

# Process video
def process_video(uploaded_file):
    tfile = tempfile.NamedTemporaryFile(delete=False)
    tfile.write(uploaded_file.read())
    tfile.close()  # Flush to disk so OpenCV can open the file
    cap = cv2.VideoCapture(tfile.name)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        st.error("Failed to read the video file.")
        return None
    # Use the first frame as a still image representing the video
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image

# Generate description
def generate_description(processor, model, device, image, user_question):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_question},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():  # Inference only; skip gradient tracking to save memory
        generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated answer is decoded
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# Generate story
def generate_story(descriptions):
    combined_text = " ".join(descriptions)
    prompt_template = PromptTemplate(
        input_variables=["descriptions"],
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:",
    )
    ollama_llm = Ollama(model="llama3.1")  # Requires a running local Ollama server with this model pulled
    # Compose prompt, LLM, and parser as a runnable chain (replaces the deprecated LLMChain)
    chain = prompt_template | ollama_llm | StrOutputParser()
    return chain.invoke({"descriptions": combined_text})

# Main function to control the flow
def main():
    st.title("Media Story Generator")

    # Keep descriptions in session state so they survive the rerun
    # triggered by clicking the "Generate Story" button; a plain local
    # list would be reset to empty on that rerun.
    if "descriptions" not in st.session_state:
        st.session_state.descriptions = []

    # Step 1: Load the model
    processor, model, device = load_model()

    # Step 2: Upload image or video
    uploaded_files = upload_media()
    if uploaded_files:
        # Step 3: Enter your question
        user_question = get_user_question()
        if user_question:
            # Step 4: Generate description
            st.write("Step 4: Generate description")
            if st.button("Generate Description"):
                st.session_state.descriptions = []
                for uploaded_file in uploaded_files:
                    file_type = uploaded_file.type.split('/')[0]
                    image = None
                    if file_type == 'image':
                        image = process_image(uploaded_file)
                    elif file_type == 'video':
                        image = process_video(uploaded_file)
                    else:
                        st.error("Unsupported file type.")
                        continue
                    if image is not None:
                        description = generate_description(processor, model, device, image, user_question)
                        st.write("Description:")
                        st.write(description)
                        st.session_state.descriptions.append(description)
                    # Clear memory after processing each file
                    del image
                    torch.cuda.empty_cache()
                    torch.manual_seed(0)
            if st.session_state.descriptions:
                # Step 5: Generate story
                if st.button("Generate Story"):
                    story = generate_story(st.session_state.descriptions)
                    st.write("Generated Story:")
                    st.write(story)


if __name__ == "__main__":
    main()
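
# To try the app locally (assumed setup; adjust names and versions to your environment):
#   pip install streamlit transformers torch opencv-python pillow langchain-core langchain-community
#   ollama pull llama3.1   # the story step expects a local Ollama server serving this model
#   streamlit run app.py   # "app.py" is the assumed filename for this script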