VinitT committed
Commit ab2cf62 · verified · 1 Parent(s): 41fe60b

Update app.py

Files changed (1)
  1. app.py +129 -143
app.py CHANGED
@@ -8,146 +8,132 @@ from langchain import LLMChain, PromptTemplate
  from langchain_community.llms import Ollama
  from langchain_core.output_parsers import StrOutputParser

- # Load the processor and model directly
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
- model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-
- # Check if CUDA is available and set the device accordingly
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
-
- # Streamlit app
- st.title("Media Description Generator")
-
- uploaded_files = st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)
- generate_button = st.button("Generate Description")
-
- if generate_button and uploaded_files:
-     user_question = st.text_input("Ask a question about the images or videos:")
-
-     if user_question:
-         all_output_texts = []  # Initialize an empty list to store all output texts
-
-         for uploaded_file in uploaded_files:
-             file_type = uploaded_file.type.split('/')[0]
-
-             if file_type == 'image':
-                 # Open the image
-                 image = Image.open(uploaded_file)
-                 # Resize image to reduce memory usage
-                 image = image.resize((256, 256))  # Reduce size to save memory
-                 st.image(image, caption='Uploaded Image.', use_column_width=True)
-                 st.write("Generating description...")
-
-             elif file_type == 'video':
-                 # Save the uploaded video to a temporary file
-                 tfile = tempfile.NamedTemporaryFile(delete=False)
-                 tfile.write(uploaded_file.read())
-
-                 # Open the video file
-                 cap = cv2.VideoCapture(tfile.name)
-
-                 # Extract the first frame
-                 ret, frame = cap.read()
-                 if not ret:
-                     st.error("Failed to read the video file.")
-                     continue
-                 else:
-                     # Convert the frame to an image
-                     image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-                     # Resize image to reduce memory usage
-                     image = image.resize((256, 256))  # Reduce size to save memory
-                     st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
-                     st.write("Generating description...")
-
-                 # Release the video capture object
-                 cap.release()
-
-             else:
-                 st.error("Unsupported file type.")
-                 continue
-
-             # Ensure the image is loaded correctly
-             if image is None:
-                 st.error("Failed to load the image.")
-                 continue
-
-             messages = [
-                 {
-                     "role": "user",
-                     "content": [
-                         {
-                             "type": "image",
-                             "image": image,
-                         },
-                         {"type": "text", "text": user_question},
-                     ],
-                 }
-             ]
-
-             # Preparation for inference
-             text = processor.apply_chat_template(
-                 messages, tokenize=False, add_generation_prompt=True
-             )
-
-             # Pass the image to the processor
-             inputs = processor(
-                 text=[text],
-                 images=[image],
-                 padding=True,
-                 return_tensors="pt",
-             )
-             inputs = inputs.to(device)  # Ensure inputs are on the same device as the model
-
-             # Inference: Generation of the output
-             try:
-                 generated_ids = model.generate(**inputs, max_new_tokens=512)
-                 generated_ids_trimmed = [
-                     out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-                 ]
-                 output_text = processor.batch_decode(
-                     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-                 )
-
-                 st.write("Description:")
-                 st.write(output_text[0])
-
-                 # Append the output text to the list
-                 all_output_texts.append(output_text[0])
-
-             except Exception as e:
-                 st.error(f"Error during generation: {e}")
-                 continue
-
-             # Clear memory after processing each file
-             del image, inputs, generated_ids, generated_ids_trimmed, output_text
-             torch.cuda.empty_cache()
-             torch.manual_seed(0)  # Reset the seed to ensure reproducibility
-
-         # Combine all descriptions into a single text
-         combined_text = " ".join(all_output_texts)
-
-         # Create a custom prompt
-         custom_prompt = f"Based on the following descriptions, create a short story:\n\n{combined_text}\n\nStory:"
-
-         # Define the prompt template for LangChain
-         prompt_template = PromptTemplate(
-             input_variables=["descriptions"],
-             template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
-         )
-
-         # Create the LLMChain with the Ollama model
-         ollama_llm = Ollama(model="llama3.1")
-         output_parser = StrOutputParser()
-         chain = LLMChain(
-             llm=ollama_llm,
-             prompt=prompt_template,
-             output_parser=output_parser
-         )
-
-         # Generate the story using LangChain
-         story = chain.run({"descriptions": combined_text})
-
-         # Display the generated story
-         st.write("Generated Story:")
-         st.write(story)
+ # Step 1: Load the model
+ def load_model():
+     st.write("Loading the model...")
+     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+     model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     st.write("Model loaded successfully!")
+     return processor, model, device
+
+ # Step 2: Upload image or video
+ def upload_media():
+     st.write("Step 2: Upload an image or video")
+     return st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)
+
+ # Step 3: Enter your question
+ def get_user_question():
+     st.write("Step 3: Enter your question")
+     return st.text_input("Ask a question about the images or videos:")
+
+ # Process image
+ def process_image(uploaded_file):
+     image = Image.open(uploaded_file)
+     image = image.resize((256, 256))  # Reduce size to save memory
+     st.image(image, caption='Uploaded Image.', use_column_width=True)
+     return image
+
+ # Process video
+ def process_video(uploaded_file):
+     tfile = tempfile.NamedTemporaryFile(delete=False)
+     tfile.write(uploaded_file.read())
+     cap = cv2.VideoCapture(tfile.name)
+     ret, frame = cap.read()
+     cap.release()
+     if not ret:
+         st.error("Failed to read the video file.")
+         return None
+     image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     image = image.resize((256, 256))  # Reduce size to save memory
+     st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
+     return image
+
+ # Generate description
+ def generate_description(processor, model, device, image, user_question):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": user_question},
+             ],
+         }
+     ]
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
+     generated_ids = model.generate(**inputs, max_new_tokens=512)
+     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+     output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     return output_text[0]
+
+ # Generate story
+ def generate_story(descriptions):
+     combined_text = " ".join(descriptions)
+     prompt_template = PromptTemplate(
+         input_variables=["descriptions"],
+         template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
+     )
+     ollama_llm = Ollama(model="llama3.1")
+     output_parser = StrOutputParser()
+     chain = LLMChain(llm=ollama_llm, prompt=prompt_template, output_parser=output_parser)
+     return chain.run({"descriptions": combined_text})
+
+ # Main function to control the flow
+ def main():
+     st.title("Media Description Generator")
+
+     # Step 1: Load the model
+     processor, model, device = load_model()
+
+     # Step 2: Upload image or video
+     uploaded_files = upload_media()
+
+     if uploaded_files:
+         # Step 3: Enter your question
+         user_question = get_user_question()
+
+         if user_question:
+             # Step 4: Generate description
+             st.write("Step 4: Generate description")
+             generate_description_button = st.button("Generate Description")
+
+             if generate_description_button:
+                 all_output_texts = []
+
+                 for uploaded_file in uploaded_files:
+                     file_type = uploaded_file.type.split('/')[0]
+                     image = None
+
+                     if file_type == 'image':
+                         image = process_image(uploaded_file)
+                     elif file_type == 'video':
+                         image = process_video(uploaded_file)
+                     else:
+                         st.error("Unsupported file type.")
+                         continue
+
+                     if image:
+                         description = generate_description(processor, model, device, image, user_question)
+                         st.write("Description:")
+                         st.write(description)
+                         all_output_texts.append(description)
+
+                     # Clear memory after processing each file
+                     del image
+                     torch.cuda.empty_cache()
+                     torch.manual_seed(0)
+
+                 if all_output_texts:
+                     # Step 5: Generate story
+                     st.write("Step 5: Generate story")
+                     generate_story_button = st.button("Generate Story")
+
+                     if generate_story_button:
+                         story = generate_story(all_output_texts)
+                         st.write("Generated Story:")
+                         st.write(story)
+
+ if __name__ == "__main__":
+     main()