import os

import numpy as np
import torch
import gradio as gr
from PIL import Image
from datasets import load_dataset
from transformers import pipeline
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils.import_utils import is_xformers_available
from pinecone import Pinecone

# Read the Pinecone API key from the environment (e.g. a Space secret named PINECONE_API_KEY)
# instead of hard-coding the credential in source.
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
# Load OpenAI CLIP model for embedding generation
import open_clip
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
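# Monocular depth-estimation pipeline; its output conditions the depth ControlNet loaded below.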
depth_estimator = pipeline("depth-estimation")
# Initialize Stable Diffusion ControlNet
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
# Load the fine-tuned LoRA adapter from the Hugging Face Hub
lora_weights_path = "rohith2812/atoi-lora-finetuned-v1"
pipe.unet.load_attn_procs(lora_weights_path)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Model CPU offload manages device placement itself, so the pipeline is not moved to CUDA explicitly.
pipe.enable_model_cpu_offload()
if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
print("Pipeline is ready with fine-tuned LoRA adapter!")
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
def retrieve_image_from_text_prompt(prompt, selected_index, knowledge_database):
"""
Retrieve the most relevant image based on the text prompt from the selected Pinecone index.
"""
dataset = load_dataset(knowledge_database, split="train")
# Initialize the Pinecone index dynamically based on user selection
index = pc.Index(selected_index)
# Generate Embedding for Text
text_tokens = tokenizer([prompt])
with torch.no_grad():
query_embedding = model.encode_text(text_tokens).cpu().numpy().flatten()
# Query Pinecone
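    # The "text_embeddings" namespace must match the namespace the embeddings were upserted into.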
results = index.query(vector=query_embedding.tolist(), top_k=1, include_metadata=True, namespace="text_embeddings")
if results and "matches" in results and results["matches"]:
best_match = results["matches"][0]
image_path = best_match["metadata"]["image_path"]
description = best_match["metadata"]["description"]
# Match the image path to the dataset to retrieve the image
for item in dataset:
if item["image_path"].endswith(image_path):
return {"image": item["image"], "description": description}
return None
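
# A minimal sketch (not called by the app) of how the "text_embeddings" namespace queried above is
# assumed to have been populated: a CLIP text embedding of each item's description, upserted with
# the image path and description as metadata (the same fields read back from the query results).
def build_text_index(dataset, index):
    for i, item in enumerate(dataset):
        tokens = tokenizer([item["description"]])
        with torch.no_grad():
            embedding = model.encode_text(tokens).cpu().numpy().flatten()
        index.upsert(
            vectors=[{
                "id": str(i),
                "values": embedding.tolist(),
                "metadata": {"image_path": item["image_path"], "description": item["description"]},
            }],
            namespace="text_embeddings",
        )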
# Function to Generate Depth Map
def get_depth_map(image):
image = depth_estimator(image)["depth"]
image = np.array(image)
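    # Replicate the single-channel depth map across three channels to form the RGB-like control image.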
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
detected_map = torch.from_numpy(image).float() / 255.0
return detected_map.permute(2, 0, 1).unsqueeze(0).half().to("cuda")
from transformers import CLIPProcessor, CLIPModel
# Load CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def calculate_clip_score(image, text):
    """Calculate a CLIP similarity score for an image and text pair."""
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to("cuda")
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image is the scaled image-text cosine similarity; a softmax over a single caption
    # would always return 1.0, so report the raw similarity instead.
    logits_per_image = outputs.logits_per_image
    clip_score = logits_per_image.max().item()
    return clip_score
def audio_to_image(audio, guidance_scale, num_inference_steps, selected_index, knowledge_database):
    # The Pinecone index and dataset are loaded inside retrieve_image_from_text_prompt,
    # so just log the user's selection here.
    print(f"Using Pinecone index: {selected_index} with knowledge base: {knowledge_database}")
# Step 1: Transcribe Audio
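    # With type="numpy", Gradio delivers the audio as a (sample_rate, waveform) tuple.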
sr, y = audio
if y.ndim > 1:
y = y.mean(axis=1) # Convert to mono
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak  # Normalize amplitude; skip for silent clips to avoid division by zero.
transcription = transcriber({"sampling_rate": sr, "raw": y})["text"]
print(f"Transcribed Text: {transcription}")
# Step 2: Retrieve Image Based on Text Prompt
print("Retrieving image from vector database...")
retrieved_data = retrieve_image_from_text_prompt(transcription, selected_index, knowledge_database)
if not retrieved_data:
return transcription, None, None, "No relevant image found.", None
retrieved_image = retrieved_data["image"]
retrieved_description = retrieved_data["description"]
# Step 3: Generate Depth Map
print("Generating depth map...")
depth_map = get_depth_map(retrieved_image)
# Step 4: Enhance Image Using Stable Diffusion
print("Enhancing image with Stable Diffusion...")
enhanced_image = pipe(
prompt=f"{transcription}. Ensure formulas are accurate and text is clean and legible.",
image=retrieved_image,
control_image=depth_map,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps
).images[0]
# Step 5: Calculate CLIP Score
print("Calculating CLIP Score...")
clip_score = calculate_clip_score(enhanced_image, transcription)
# Return Retrieved and Enhanced Images with CLIP Score
return transcription, retrieved_image, enhanced_image, retrieved_description, clip_score
# Gradio Interface Function
def gradio_interface(audio, guidance_scale, num_inference_steps, selected_index, knowledge_database):
transcription, retrieved_image, enhanced_image, retrieved_description, clip_score = audio_to_image(
audio, guidance_scale, num_inference_steps, selected_index, knowledge_database
)
    if enhanced_image is None:
        # Return None for the image slots so the gr.Image components stay empty.
        return transcription, None, None, retrieved_description, "N/A"
return transcription, retrieved_image, enhanced_image, retrieved_description, clip_score
# Enhanced Gradio UI
with gr.Blocks(title="Audio-to-Image Generation") as demo:
gr.Markdown(
"""
# 🎨 Audio-to-Image Generation with AI
Speak into the microphone: the app transcribes your prompt, retrieves the most relevant image from the
selected knowledge base, enhances it with Stable Diffusion + ControlNet, and reports the retrieved
description and a CLIP score.
"""
)
with gr.Row():
with gr.Column():
audio_input = gr.Audio(type="numpy", label="🎀 Speak Your Prompt")
guidance_scale_input = gr.Slider(
minimum=1.0, maximum=20.0, step=0.5, value=8.5, label="πŸŽ›οΈ Guidance Scale"
)
num_inference_steps_input = gr.Slider(
minimum=10, maximum=200, step=10, value=100, label="πŸ”’ Number of Inference Steps"
)
index_selection = gr.Dropdown(
choices=["project-atoi-v2", "project-atoi"],
value="project-atoi-v2",
label="πŸ—‚οΈ Select Pinecone Index"
)
knowledge_database_selection = gr.Dropdown(
choices=["rohith2812/atoigeneration-final-data", "rxc5667/3wordsdataset_noduplicates"],
value="rxc5667/3wordsdataset_noduplicates",
label="πŸ“š Select Knowledge Database"
)
submit_button = gr.Button("Generate Image")
with gr.Column():
transcription_output = gr.Textbox(label="πŸ“ Transcribed Prompt")
retrieved_image_output = gr.Image(label="πŸ–ΌοΈ Retrieved Image")
enhanced_image_output = gr.Image(label="✨ Enhanced Image")
retrieved_description_output = gr.Textbox(label="πŸ“œ Retrieved Description")
clip_score_output = gr.Textbox(label="πŸ“Š CLIP Score")
    # gr.Examples cannot populate an Audio input from text strings, so the sample prompts
    # are listed as suggestions for the user to speak instead.
    gr.Markdown(
        """
        **Example prompts to try:**
        - "A picture explaining the line of best fit in linear regression"
        - "Support vector machines"
        - "A picture explaining multiple components in PCA"
        """
    )
submit_button.click(
fn=gradio_interface,
inputs=[
audio_input,
guidance_scale_input,
num_inference_steps_input,
index_selection,
knowledge_database_selection,
],
outputs=[
transcription_output,
retrieved_image_output,
enhanced_image_output,
retrieved_description_output,
clip_score_output,
],
)
# Launch Gradio Interface
if __name__ == "__main__":
demo.launch()