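"""Streamlit app that turns an uploaded image into a short story.

Pipeline: YOLOv8 object detection, MiDaS v2.1 depth estimation, Places365
(ResNet50) scene classification, BLIP captioning of object crops, and story
generation with the Gemini API ("gemini-1.5-flash").

Expected local assets: yolov8n.pt, midas_v21_384.pt (with the MiDaS repo on
the import path), resnet50_places365.pth.tar, categories_places365.txt, and
an API_KEY entry in a .env file.
"""
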
import torch

from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
)

from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import cv2
from dotenv import load_dotenv
import os
import torchvision.models as models
from torchvision import transforms
import warnings
import google.generativeai as genai
import streamlit as st
from MiDaS.midas.midas_net import MidasNet
import io

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

load_dotenv()
API_KEY = os.getenv('API_KEY')


def check_overlap(bbox1, bbox2):
    """
    Checks if two bounding boxes overlap.
    Each bbox is (x_min, y_min, x_max, y_max).
    """
    x1_min, y1_min, x1_max, y1_max = bbox1
    x2_min, y2_min, x2_max, y2_max = bbox2

    # No overlap if the boxes are separated horizontally or vertically.
    if x1_max < x2_min or x2_max < x1_min:
        return False
    if y1_max < y2_min or y2_max < y1_min:
        return False

    return True


def draw_bounding_boxes(image: Image.Image, objects: list) -> Image.Image:
    """Draws bounding boxes and labels on the image."""
    image_draw = image.copy()
    draw_obj = ImageDraw.Draw(image_draw)
    font = ImageFont.load_default()
    for obj in objects:
        x_min, y_min, x_max, y_max = obj["bounding_box"]
        label = obj["label"]
        caption = obj.get("caption", "")
        depth_position = obj.get("depth_position", "Unknown")
        color = "red" if depth_position == "Front" else "blue"
        draw_obj.rectangle([x_min, y_min, x_max, y_max], outline=color, width=2)
        text = f"{label}: {caption}"
        try:
            text_width = draw_obj.textlength(text, font=font)
            ascent, descent = font.getmetrics()
            text_height = ascent + descent
        except Exception as e:
            print(f"Error calculating text size for '{text}': {e}")
            text_width, text_height = 0, 0
        if text_height <= 0:
            text_height = 10
        draw_obj.rectangle(
            [x_min, y_min - text_height, x_min + text_width, y_min],
            fill=color,
        )
        draw_obj.text(
            (x_min, y_min - text_height),
            text,
            fill="white",
            font=font,
        )
    return image_draw


class ImagePreprocessor:
    """Preprocesses the input image for further analysis."""

    @staticmethod
    def preprocess(image: Image.Image) -> Image.Image:
        """Applies preprocessing steps to the image."""
        image = image.convert("RGB")

        # Downscale large images in place while preserving aspect ratio.
        max_size = 800
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.LANCZOS)
        return image


class ObjectDetector:
    """Detects objects in images using YOLOv8."""

    def __init__(self, model_path: str, confidence_threshold: float = 0.5):
        self.model = YOLO(model_path)
        self.confidence_threshold = confidence_threshold
        print(f"Initialized ObjectDetector with model: {model_path}")

    def detect_objects(self, image: Image.Image) -> list:
        """Detects objects in the image and returns bounding boxes and labels."""
        results = self.model.predict(
            source=np.array(image), conf=self.confidence_threshold, verbose=False
        )
        detections = []
        for result in results[0].boxes:
            x_min, y_min, x_max, y_max = result.xyxy[0].cpu().numpy()
            conf = result.conf.cpu().item()
            cls = int(result.cls.cpu().item())
            label = self.model.names[cls]
            detections.append(
                {
                    "label": label,
                    "bounding_box": (
                        int(x_min),
                        int(y_min),
                        int(x_max),
                        int(y_max),
                    ),
                    "confidence": conf,
                }
            )
            print(
                f"Detected {label} with confidence {conf:.2f} "
                f"at [{int(x_min)}, {int(y_min)}, {int(x_max)}, {int(y_max)}]"
            )
        return detections

    def draw_detections(self, image: Image.Image, detections: list) -> Image.Image:
        """Draws bounding boxes and labels on the image."""
        draw = ImageDraw.Draw(image)

        for detection in detections:
            x_min, y_min, x_max, y_max = detection["bounding_box"]
            label = detection["label"]
            confidence = detection["confidence"]

            draw.rectangle(
                [(x_min, y_min), (x_max, y_max)],
                outline="red",
                width=3,
            )

            # Rough label background: about 6 px per character with the default font.
            text = f"{label} ({confidence:.2f})"
            text_position = (x_min, y_min - 10)
            draw.rectangle(
                [text_position, (x_min + len(text) * 6, y_min)],
                fill="red",
            )
            draw.text(
                text_position,
                text,
                fill="white",
            )

        return image


class DepthEstimator:
    """Estimates depth maps for images using a locally saved MiDaS v2.1 model."""

    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        print(f"Initialized DepthEstimator with model path: {model_path}")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.model = MidasNet(model_path, non_negative=True).to(self.device)
        self.model.eval()
        self.transform = self._get_transform()

    def _get_transform(self):
        """Defines the transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def get_depth_map(self, image: Image.Image) -> np.ndarray:
        """Generates a normalized depth map for the given image."""
        input_batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            prediction = self.model(input_batch)
            # Resize the prediction back to the original resolution; PIL's size is
            # (width, height), while interpolate expects (height, width).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depth_map = prediction.cpu().numpy()

        # MiDaS predicts relative inverse depth (larger = closer); normalize to
        # [0, 1] for the Front/Back split. The epsilon guards against a flat map.
        normalized_depth = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
        return normalized_depth


class EnvironmentClassifier:
    """Classifies the overall environment using a Places365 pre-trained ResNet50."""

    def __init__(self):
        model_file = "resnet50_places365.pth.tar"
        print(f"Initialized EnvironmentClassifier with model: {model_file}")
        self.model = models.resnet50(num_classes=365)
        checkpoint = torch.load(model_file, map_location=device)
        # The released checkpoint was trained with DataParallel, so its keys
        # carry a "module." prefix that must be stripped before loading.
        state_dict = {
            str.replace(k, "module.", ""): v
            for k, v in checkpoint["state_dict"].items()
        }
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(device)

        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

        # Each line of categories_places365.txt looks like "/a/abbey 0";
        # keep only the class name (drop the leading "/x/" and the index).
        file_name = "categories_places365.txt"
        self.classes = []
        with open(file_name) as class_file:
            for line in class_file:
                self.classes.append(line.strip().split(" ")[0][3:])
        self.classes = tuple(self.classes)

    def classify_environment(self, image: Image.Image) -> str:
        """Classifies the environment of the image."""
        input_img = self.transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = self.model(input_img)
            probs = torch.nn.functional.softmax(logits, dim=1)
            top_prob, top_idx = torch.topk(probs, 1)
        label = self.classes[top_idx.item()]
        return label


class ObjectCaptioner:
    """Generates captions for given image crops using the BLIP model."""

    def __init__(self):
        model_name = "Salesforce/blip-image-captioning-base"
        print(f"Initialized ObjectCaptioner with model: {model_name}")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)

    def generate_caption(self, image: Image.Image) -> str:
        """Generates a caption for a given image crop."""
        inputs = self.processor(image, return_tensors="pt").to(device)
        # Beam search with an n-gram repetition penalty keeps captions short
        # and avoids the model repeating itself.
        output = self.model.generate(
            **inputs,
            max_new_tokens=50,
            num_beams=5,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
        caption = self.processor.decode(output[0], skip_special_tokens=True)
        return caption


def generate_story_with_genre(prompt):
    """
    Calls the Gemini API with the constructed prompt, which includes the requested genre.
    """
    genai.configure(api_key=API_KEY)
    try:
        model = genai.GenerativeModel("gemini-1.5-flash")
        generation_config = {
            "temperature": 0.3,
        }
        response = model.generate_content(contents=prompt, generation_config=generation_config)
        return response.text

    except Exception as e:
        return f"Error occurred during the API call: {str(e)}"


def construct_gemini_prompt(environment, objects, overlapping_regions, back_caption):
    """
    Constructs a prompt for Gemini from the outputs of the vision models.
    """
    prompt = f"The scene is classified as a '{environment}'.\n"
    prompt += "In the scene, the following objects are present:\n"

    # Only foreground ("Front") objects are listed individually.
    for obj in objects:
        if obj["depth_position"] == "Front":
            label = obj["label"]
            caption = obj.get("caption", "")
            prompt += f"- {label}: {caption}\n"

    if overlapping_regions:
        prompt += "\nThere are overlapping regions involving the following objects:\n"
        for region in overlapping_regions:
            labels = [obj["label"] for obj in region["objects"]]
            caption = region.get("caption", "")
            prompt += f"- Overlapping {', '.join(labels)}: {caption}\n"

    if back_caption:
        prompt += "\nThe background of the scene includes:\n"
        prompt += f"- {back_caption}\n"

    prompt += (
        "\nPlease create a realistic story that describes the scene, "
        "strictly incorporating only the objects above and their interactions."
    )

    return prompt


def main():
    """Main function to process the image and generate descriptions in a Streamlit app."""
    st.set_page_config(page_title="The Art of Visual Storytelling", layout="wide")
    st.title("The Art of Visual Storytelling...")
    st.subheader("Convert an image to a story")

    with st.sidebar:
        st.header("Upload and Settings")
        image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        genre = st.selectbox(
            "Select Genre",
            ["Normal", "Adventure", "Fantasy", "Sci-Fi", "Mystery", "Romance", "Horror", "Comedy"],
        )

    if image_file:
        image = Image.open(image_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("Generate Story"):
            with st.spinner("Processing..."):
                image_path = "input_image.jpg"
                image.save(image_path)

                try:
                    main_with_image(image_path, genre)
                except Exception as e:
                    st.error(f"Error: {e}")


def main_with_image(image_path: str, genre: str):
    """Main processing function from the original code, adapted for Streamlit."""
    if not os.path.exists(image_path):
        st.error(f"Error: Image path '{image_path}' does not exist.")
        return

    try:
        preprocessor = ImagePreprocessor()
        image = Image.open(image_path).convert("RGB")
        image = preprocessor.preprocess(image)
        object_detector = ObjectDetector(model_path="yolov8n.pt", confidence_threshold=0.5)
        depth_estimator = DepthEstimator(model_path="midas_v21_384.pt")
        environment_classifier = EnvironmentClassifier()
        captioner = ObjectCaptioner()
    except Exception as e:
        st.error(f"Error initializing components: {e}")
        return

    # Defaults so later steps can still run if an earlier step fails.
    environment = "Unknown"
    depth_map = None
    detections = []
    overlapping_regions = []

    with st.expander("Step 1: Classify the Environment"):
        try:
            environment = environment_classifier.classify_environment(image)
            st.write(f"Environment: {environment}")
        except Exception as e:
            st.error(f"Error classifying environment: {e}")

    with st.expander("Step 2: Depth Estimation"):
        try:
            depth_map = depth_estimator.get_depth_map(image)
            st.write(
                f"Depth map generated with shape {depth_map.shape}, "
                f"min {depth_map.min():.2f}, max {depth_map.max():.2f}"
            )
            buffer = io.BytesIO()
            plt.imsave(buffer, depth_map, cmap="viridis", format="png")
            buffer.seek(0)

            st.image(buffer, caption="Depth Map", use_container_width=True)
        except Exception as e:
            st.error(f"Error estimating depth: {e}")

    with st.expander("Step 3: Object Detection"):
        try:
            detections = object_detector.detect_objects(image)
            st.write(f"Detected {len(detections)} objects.")
            for detection in detections:
                st.write(f"- {detection['label']} at {detection['bounding_box']}")
            annotated_image = object_detector.draw_detections(image.copy(), detections)

            st.image(annotated_image, caption="Detected Objects", use_container_width=True)
        except Exception as e:
            st.error(f"Error detecting objects: {e}")

    with st.expander("Step 4: Process Objects Based on Depth"):
        filtered_objects = []
        if depth_map is not None and len(detections) > 0:
            for detection in detections:
                x_min, y_min, x_max, y_max = detection["bounding_box"]
                # Clamp the box to the image bounds before sampling the depth map.
                x_min = max(0, x_min)
                y_min = max(0, y_min)
                x_max = min(image.width - 1, x_max)
                y_max = min(image.height - 1, y_max)
                object_depth = depth_map[y_min:y_max, x_min:x_max]
                average_depth = np.mean(object_depth) if object_depth.size > 0 else 0.0
                # Normalized inverse depth: values >= 0.5 are treated as foreground.
                depth_position = "Front" if average_depth >= 0.5 else "Back"
                detection["depth_position"] = depth_position
                detection["average_depth"] = average_depth
                filtered_objects.append(detection)
                st.write(
                    f"Object {detection['label']} has average depth {average_depth:.2f}, "
                    f"depth position: {depth_position}"
                )
        else:
            st.warning("No depth map or detections available.")
            for detection in detections:
                detection["depth_position"] = "Unknown"
                detection["average_depth"] = None
                filtered_objects.append(detection)

    with st.expander("Step 5: Find Overlapping Bounding Boxes"):
        try:
            overlapping_regions = []
            num_objects = len(filtered_objects)
            for i in range(num_objects):
                obj1 = filtered_objects[i]
                for j in range(i + 1, num_objects):
                    obj2 = filtered_objects[j]
                    # Only pair foreground objects; record the intersection of their boxes.
                    if obj1["depth_position"] == "Front" and obj2["depth_position"] == "Front":
                        bbox1 = obj1["bounding_box"]
                        bbox2 = obj2["bounding_box"]
                        overlap = check_overlap(bbox1, bbox2)
                        if overlap:
                            x1_min, y1_min, x1_max, y1_max = bbox1
                            x2_min, y2_min, x2_max, y2_max = bbox2
                            x_overlap_min = max(x1_min, x2_min)
                            y_overlap_min = max(y1_min, y2_min)
                            x_overlap_max = min(x1_max, x2_max)
                            y_overlap_max = min(y1_max, y2_max)
                            overlapping_region = {
                                "bounding_box": (
                                    x_overlap_min,
                                    y_overlap_min,
                                    x_overlap_max,
                                    y_overlap_max,
                                ),
                                "objects": [obj1, obj2],
                                "depth_position": obj1["depth_position"],
                            }
                            overlapping_regions.append(overlapping_region)
            st.write(f"Found {len(overlapping_regions)} overlapping regions.")
        except Exception as e:
            st.error(f"Error processing overlapping bounding boxes: {e}")

    with st.expander("Step 6: Caption Objects"):
        front_objects = [obj for obj in filtered_objects if obj["depth_position"] == "Front"]
        back_objects = [obj for obj in filtered_objects if obj["depth_position"] == "Back"]

        # Caption each foreground object from its cropped region.
        for obj in front_objects:
            try:
                x_min, y_min, x_max, y_max = obj["bounding_box"]
                cropped_image = image.crop((x_min, y_min, x_max, y_max))
                obj["caption"] = captioner.generate_caption(cropped_image)
                st.write(f"- {obj['label']}: {obj['caption']}")
            except Exception as e:
                obj["caption"] = ""
                st.error(f"Error generating caption for {obj['label']}: {e}")

        # Caption the intersection region of each overlapping pair.
        for region in overlapping_regions:
            x_min, y_min, x_max, y_max = region["bounding_box"]
            try:
                cropped_image = image.crop((x_min, y_min, x_max, y_max))
                caption = captioner.generate_caption(cropped_image)
                region["caption"] = caption
                labels = [obj["label"] for obj in region["objects"]]
                st.write(f"- Overlapping {labels}: {caption}")
            except Exception as e:
                region["caption"] = ""
                st.error(f"Error generating caption for overlapping region: {e}")

        # Caption all background objects together using their union bounding box.
        back_caption = ""
        if back_objects:
            x_mins = [obj["bounding_box"][0] for obj in back_objects]
            y_mins = [obj["bounding_box"][1] for obj in back_objects]
            x_maxs = [obj["bounding_box"][2] for obj in back_objects]
            y_maxs = [obj["bounding_box"][3] for obj in back_objects]
            x_min = min(x_mins)
            y_min = min(y_mins)
            x_max = max(x_maxs)
            y_max = max(y_maxs)
            try:
                cropped_image = image.crop((x_min, y_min, x_max, y_max))
                back_caption = captioner.generate_caption(cropped_image)
                st.write(f"- Back objects caption: {back_caption}")
            except Exception as e:
                st.error(f"Error generating caption for Back objects: {e}")

    with st.expander("Step 7: Construct Prompt for Gemini"):
        gemini_prompt = construct_gemini_prompt(environment, filtered_objects, overlapping_regions, back_caption)
        gemini_prompt += f"\nGenre for story: {genre}\n"
        st.write("\nPrompt for Gemini:")
        st.write(gemini_prompt)

    with st.expander("Step 8: Generated Story"):
        try:
            story = generate_story_with_genre(prompt=gemini_prompt)
            st.write("Generated Story:")
            st.write(story)
        except Exception as e:
            st.error(f"Error generating story: {e}")


if __name__ == "__main__":
    main()