import torch
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
)
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import cv2
from dotenv import load_dotenv
import os
import torchvision.models as models
from torchvision import transforms
import warnings
import google.generativeai as genai
import streamlit as st
from MiDaS.midas.midas_net import MidasNet
import io
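# Run all models on the GPU when available, otherwise fall back to the CPU.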
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
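# Load the Gemini API key from a local .env file.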
load_dotenv()
API_KEY = os.getenv('API_KEY')
def check_overlap(bbox1, bbox2):
"""
Checks if two bounding boxes overlap.
Each bbox is (x_min, y_min, x_max, y_max)
"""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
    # No overlap if the boxes are separated horizontally
if x1_max < x2_min or x2_max < x1_min:
return False
    # No overlap if the boxes are separated vertically
if y1_max < y2_min or y2_max < y1_min:
return False
return True
def draw_bounding_boxes(image: Image.Image, objects: list) -> Image.Image:
"""Draws bounding boxes and labels on the image."""
image_draw = image.copy()
draw_obj = ImageDraw.Draw(image_draw)
font = ImageFont.load_default()
for obj in objects:
x_min, y_min, x_max, y_max = obj["bounding_box"]
label = obj["label"]
caption = obj.get("caption", "")
depth_position = obj.get("depth_position", "Unknown")
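        # Color-code boxes by depth: red for "Front" objects, blue for everything else.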
color = "red" if depth_position == "Front" else "blue"
draw_obj.rectangle([x_min, y_min, x_max, y_max], outline=color, width=2)
text = f"{label}: {caption}"
try:
text_width = draw_obj.textlength(text, font=font)
ascent, descent = font.getmetrics()
text_height = ascent + descent
except Exception as e:
print(f"Error calculating text size for '{text}': {e}")
text_width, text_height = 0, 0
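        # Fall back to a fixed height if the font metrics could not be computed.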
if text_height <= 0:
text_height = 10
draw_obj.rectangle(
[x_min, y_min - text_height, x_min + text_width, y_min],
fill=color,
)
draw_obj.text(
(x_min, y_min - text_height),
text,
fill="white",
font=font,
)
return image_draw
class ImagePreprocessor:
"""Preprocesses the input image for further analysis."""
@staticmethod
def preprocess(image: Image.Image) -> Image.Image:
"""Applies preprocessing steps to the image."""
# Convert to RGB
image = image.convert("RGB")
# Resize image if necessary
max_size = 800
if max(image.size) > max_size:
image.thumbnail((max_size, max_size), Image.LANCZOS)
return image
class ObjectDetector:
"""Detects objects in images using YOLOv8."""
def __init__(self, model_path: str, confidence_threshold: float = 0.5):
self.model = YOLO(model_path)
self.confidence_threshold = confidence_threshold
print(f"Initialized ObjectDetector with model: {model_path}")
def detect_objects(self, image: Image.Image) -> list:
"""Detects objects in the image and returns bounding boxes and labels."""
results = self.model.predict(
source=np.array(image), conf=self.confidence_threshold, verbose=False
)
detections = []
for result in results[0].boxes:
x_min, y_min, x_max, y_max = result.xyxy[0].cpu().numpy()
conf = result.conf.cpu().item()
cls = int(result.cls.cpu().item())
label = self.model.names[cls]
detections.append(
{
"label": label,
"bounding_box": (
int(x_min),
int(y_min),
int(x_max),
int(y_max),
),
"confidence": conf,
}
)
print(
f"Detected {label} with confidence {conf:.2f} at [{int(x_min)}, {int(y_min)}, {int(x_max)}, {int(y_max)}]"
)
return detections
def draw_detections(self, image: Image.Image, detections: list) -> Image.Image:
"""Draws bounding boxes and labels on the image."""
draw = ImageDraw.Draw(image)
for detection in detections:
x_min, y_min, x_max, y_max = detection["bounding_box"]
label = detection["label"]
confidence = detection["confidence"]
# Draw bounding box
draw.rectangle(
[(x_min, y_min), (x_max, y_max)],
outline="red",
width=3,
)
text = f"{label} ({confidence:.2f})"
text_position = (x_min, y_min - 10)
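            # Approximate the label background width (about 6 px per character) for the default font.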
draw.rectangle(
[text_position, (x_min + len(text) * 6, y_min)],
fill="red",
)
draw.text(
text_position,
text,
fill="white",
)
return image
class DepthEstimator:
"""Estimates depth maps for images using a locally saved MiDaS v2.1 Small model."""
def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
self.device = torch.device(device)
print(f"Initialized DepthEstimator with model path: {model_path}")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.model = MidasNet(model_path, non_negative=True).to(self.device)
self.model.eval()
self.transform = self._get_transform()
def _get_transform(self):
"""Defines the transformation pipeline for input images."""
return transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def get_depth_map(self, image: Image.Image) -> np.ndarray:
"""Generates a normalized depth map for the given image."""
# Apply the transformation to the input image
input_batch = self.transform(image).unsqueeze(0).to(self.device)
with torch.no_grad():
# Perform depth estimation
prediction = self.model(input_batch)
# Resize the prediction to match the original image size
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=image.size[::-1],
mode="bicubic",
align_corners=False,
).squeeze()
depth_map = prediction.cpu().numpy()
        # Normalize the depth map to the range [0, 1]; the small epsilon guards against a constant map.
        depth_range = depth_map.max() - depth_map.min()
        normalized_depth = (depth_map - depth_map.min()) / (depth_range + 1e-8)
        return normalized_depth
class EnvironmentClassifier:
"""Classifies the overall environment using Places365 pre-trained ResNet50."""
def __init__(self):
model_file = "resnet50_places365.pth.tar"
print(f"Initialized EnvironmentClassifier with model: {model_file}")
self.model = models.resnet50(num_classes=365)
checkpoint = torch.load(model_file, map_location=device)
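        # The published checkpoint stores keys with a "module." prefix (saved from DataParallel); strip it before loading.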
state_dict = {
str.replace(k, "module.", ""): v
for k, v in checkpoint["state_dict"].items()
}
self.model.load_state_dict(state_dict)
self.model.eval()
self.model.to(device)
# image transformer
self.transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
# load the class labels
file_name = "categories_places365.txt"
self.classes = []
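        # Each line looks like "/a/abbey 0"; keep only the category name after the three-character "/x/" prefix.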
with open(file_name) as class_file:
for line in class_file:
self.classes.append(line.strip().split(" ")[0][3:])
self.classes = tuple(self.classes)
def classify_environment(self, image: Image.Image) -> str:
"""Classifies the environment of the image."""
input_img = self.transform(image).unsqueeze(0).to(device)
with torch.no_grad():
logits = self.model(input_img)
probs = torch.nn.functional.softmax(logits, dim=1)
top_prob, top_idx = torch.topk(probs, 1)
label = self.classes[top_idx.item()]
return label
class ObjectCaptioner:
"""Generates captions for given image crops using BLIP model."""
def __init__(self):
model_name = "Salesforce/blip-image-captioning-base"
print(f"Initialized ObjectCaptioner with model: {model_name}")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.processor = BlipProcessor.from_pretrained(model_name)
self.model = BlipForConditionalGeneration.from_pretrained(
model_name
).to(device)
def generate_caption(self, image: Image.Image) -> str:
"""Generates a caption for a given image crop."""
inputs = self.processor(image, return_tensors="pt").to(device)
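        # Beam search (5 beams); no_repeat_ngram_size=2 blocks repeated bigrams in the generated caption.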
output = self.model.generate(
**inputs,
max_new_tokens=50,
num_beams=5,
no_repeat_ngram_size=2,
early_stopping=True,
)
caption = self.processor.decode(output[0], skip_special_tokens=True)
return caption
def generate_story_with_genre(prompt):
"""
    Sends the constructed prompt (which already includes the selected genre) to the Gemini API and returns the generated story.
"""
genai.configure(api_key=API_KEY)
try:
model = genai.GenerativeModel("gemini-1.5-flash")
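        # Low temperature (0.3) for more deterministic output that stays close to the prompt details.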
generation_config = {
"temperature": 0.3,
}
response = model.generate_content(contents=prompt, generation_config=generation_config)
return response.text
except Exception as e:
return f"Error occurred during the API call: {str(e)}"
def construct_gemini_prompt(environment, objects, overlapping_regions, back_caption):
"""
Constructs a prompt for Gemini using the outputs of the models.
"""
prompt = f"The scene is classified as a '{environment}'.\n"
prompt += "In the scene, the following objects are present:\n"
for obj in objects:
if obj["depth_position"] == "Front":
label = obj["label"]
caption = obj.get("caption", "")
prompt += f"- {label}: {caption}\n"
if overlapping_regions:
prompt += "\nThere are overlapping regions involving the following objects:\n"
for region in overlapping_regions:
labels = [obj["label"] for obj in region["objects"]]
caption = region.get("caption", "")
prompt += f"- Overlapping {', '.join(labels)}: {caption}\n"
if back_caption:
prompt += "\nThe background of the scene includes:\n"
prompt += f"- {back_caption}\n"
prompt += "\nPlease create a realistic-like story that describes the scene, incorporating the objects and their interactions only strictly."
return prompt
def main():
"""Main function to process the image and generate descriptions in a Streamlit app."""
st.set_page_config(page_title="The Art of Visual Storytelling", layout="wide")
st.title("The Art of Visual Storytelling...")
st.subheader("Convert an image to a story")
# Sidebar for file upload and genre selection
with st.sidebar:
st.header("Upload and Settings")
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
genre = st.selectbox(
"Select Genre",
["Normal", "Adventure", "Fantasy", "Sci-Fi", "Mystery", "Romance", "Horror", "Comedy"]
)
if image_file:
image = Image.open(image_file).convert("RGB")
st.image(image, caption="Uploaded Image", use_container_width=True)
# Button to trigger the story generation
if st.button("Generate Story"):
with st.spinner("Processing..."):
image_path = "input_image.jpg"
image.save(image_path)
try:
main_with_image(image_path, genre)
except Exception as e:
st.error(f"Error: {e}")
def main_with_image(image_path: str, genre: str):
"""Main processing function from original code, adapted for Streamlit."""
if not os.path.exists(image_path):
st.error(f"Error: Image path '{image_path}' does not exist.")
return
# Set up components
try:
preprocessor = ImagePreprocessor()
image = Image.open(image_path).convert("RGB")
image = preprocessor.preprocess(image)
object_detector = ObjectDetector(model_path="yolov8n.pt", confidence_threshold=0.5)
depth_estimator = DepthEstimator(model_path="midas_v21_384.pt")
environment_classifier = EnvironmentClassifier()
captioner = ObjectCaptioner()
except Exception as e:
st.error(f"Error initializing components: {e}")
return
    # Defaults so later steps do not fail with a NameError if an earlier step errors out.
    environment, depth_map = "Unknown", None
    detections, overlapping_regions = [], []
    # Environment Classification
with st.expander("Step 1: Classify the Environment"):
try:
environment = environment_classifier.classify_environment(image)
st.write(f"Environment: {environment}")
except Exception as e:
st.error(f"Error classifying environment: {e}")
# Depth Estimation
with st.expander("Step 2: Depth Estimation"):
try:
depth_map = depth_estimator.get_depth_map(image)
st.write(f"Depth map generated with shape {depth_map.shape}, min {depth_map.min():.2f}, max {depth_map.max():.2f}")
buffer = io.BytesIO()
plt.imsave(buffer, depth_map, cmap="viridis", format="png")
buffer.seek(0)
# Display the depth map
st.image(buffer, caption="Depth Map", use_container_width=True)
except Exception as e:
st.error(f"Error estimating depth: {e}")
# Object Detection
with st.expander("Step 3: Object Detection"):
try:
detections = object_detector.detect_objects(image)
st.write(f"Detected {len(detections)} objects.")
for detection in detections:
st.write(f"- {detection['label']} at {detection['bounding_box']}")
annotated_image = object_detector.draw_detections(image.copy(), detections)
# Display the annotated image
st.image(annotated_image, caption="Detected Objects", use_container_width=True)
except Exception as e:
st.error(f"Error detecting objects: {e}")
# Filter Objects Based on Depth
with st.expander("Step 4: Process Objects Based on Depth"):
filtered_objects = []
if depth_map is not None and len(detections) > 0:
for detection in detections:
x_min, y_min, x_max, y_max = detection["bounding_box"]
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(image.width - 1, x_max)
y_max = min(image.height - 1, y_max)
object_depth = depth_map[y_min:y_max, x_min:x_max]
average_depth = np.mean(object_depth) if object_depth.size > 0 else 0.0
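                # MiDaS predicts relative inverse depth, so higher normalized values are closer to the camera; >= 0.5 is treated as "Front".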
depth_position = "Front" if average_depth >= 0.5 else "Back"
detection["depth_position"] = depth_position
detection["average_depth"] = average_depth
filtered_objects.append(detection)
st.write(f"Object {detection['label']} has average depth {average_depth:.2f}, depth position: {depth_position}")
else:
st.warning("No depth map or detections available.")
for detection in detections:
detection["depth_position"] = "Unknown"
detection["average_depth"] = None
filtered_objects.append(detection)
# Handle Overlapping Bounding Boxes
with st.expander("Step 5: Find Overlapping Bounding Boxes"):
try:
overlapping_regions = []
num_objects = len(filtered_objects)
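            # Pairwise comparison, restricted to objects that are both in the "Front" depth band.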
for i in range(num_objects):
obj1 = filtered_objects[i]
for j in range(i + 1, num_objects):
obj2 = filtered_objects[j]
if obj1["depth_position"] == "Front" and obj2["depth_position"] == "Front":
bbox1 = obj1["bounding_box"]
bbox2 = obj2["bounding_box"]
overlap = check_overlap(bbox1, bbox2)
if overlap:
# Compute overlapping region
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
x_overlap_min = max(x1_min, x2_min)
y_overlap_min = max(y1_min, y2_min)
x_overlap_max = min(x1_max, x2_max)
y_overlap_max = min(y1_max, y2_max)
overlapping_region = {
"bounding_box": (
x_overlap_min,
y_overlap_min,
x_overlap_max,
y_overlap_max,
),
"objects": [obj1, obj2],
"depth_position": obj1["depth_position"],
}
overlapping_regions.append(overlapping_region)
st.write(f"Found {len(overlapping_regions)} overlapping regions.")
except Exception as e:
st.error(f"Error processing overlapping bounding boxes: {e}")
# Captioning of Objects in 'Front' and Overlapping Regions
with st.expander("Step 6: Caption Objects"):
front_objects = [obj for obj in filtered_objects if obj["depth_position"] == "Front"]
back_objects = [obj for obj in filtered_objects if obj["depth_position"] == "Back"]
# Caption objects in the Front
for obj in front_objects:
try:
x_min, y_min, x_max, y_max = obj["bounding_box"]
cropped_image = image.crop((x_min, y_min, x_max, y_max))
obj["caption"] = captioner.generate_caption(cropped_image)
st.write(f"- {obj['label']}: {obj['caption']}")
except Exception as e:
obj["caption"] = ""
st.error(f"Error generating caption for {obj['label']}: {e}")
# Caption overlapping regions
for region in overlapping_regions:
x_min, y_min, x_max, y_max = region["bounding_box"]
try:
cropped_image = image.crop((x_min, y_min, x_max, y_max))
caption = captioner.generate_caption(cropped_image)
region["caption"] = caption
labels = [obj["label"] for obj in region["objects"]]
st.write(f"- Overlapping {labels}: {caption}")
except Exception as e:
region["caption"] = ""
st.error(f"Error generating caption for overlapping region: {e}")
# Generate a common caption for 'Back' objects
back_caption = ""
if back_objects:
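            # Merge all "Back" boxes into one enclosing box and caption that single region.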
x_mins = [obj["bounding_box"][0] for obj in back_objects]
y_mins = [obj["bounding_box"][1] for obj in back_objects]
x_maxs = [obj["bounding_box"][2] for obj in back_objects]
y_maxs = [obj["bounding_box"][3] for obj in back_objects]
x_min = min(x_mins)
y_min = min(y_mins)
x_max = max(x_maxs)
y_max = max(y_maxs)
try:
cropped_image = image.crop((x_min, y_min, x_max, y_max))
back_caption = captioner.generate_caption(cropped_image)
st.write(f"- Back objects caption: {back_caption}")
except Exception as e:
st.error(f"Error generating caption for Back objects: {e}")
# Construct prompt for Gemini
with st.expander("Step 7: Construct Prompt for Gemini"):
gemini_prompt = construct_gemini_prompt(environment, filtered_objects, overlapping_regions, back_caption)
gemini_prompt += f"\nGenre for story: {genre}\n"
st.write("\nPrompt for Gemini:")
st.write(gemini_prompt)
# Generate and display the final story
with st.expander("Step 8: Generated Story"):
try:
story = generate_story_with_genre(prompt=gemini_prompt)
st.write("Generated Story:")
st.write(story)
except Exception as e:
st.error(f"Error generating story: {e}")
if __name__ == "__main__":
main()