import os

import numpy as np
import torch
import gradio as gr
import open_clip
from PIL import Image
from datasets import load_dataset
from transformers import pipeline
from pinecone import Pinecone
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils.import_utils import is_xformers_available

# Read the Pinecone API key from the environment (e.g. a Space secret) instead of hard-coding it
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

# Load the OpenAI CLIP model (via open_clip) for embedding generation
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
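# Note: ViT-B-32 CLIP text embeddings are 512-dimensional; the Pinecone indexes queried
# below are assumed to have been built with the same model and dimensionality.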
# Depth-estimation pipeline used to build the ControlNet conditioning image
depth_estimator = pipeline("depth-estimation")

# Initialize Stable Diffusion ControlNet (depth-conditioned img2img)
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16
).to("cuda")
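# "control_v11f1p_sd15_depth" is the ControlNet 1.1 depth checkpoint for Stable Diffusion 1.5,
# so it is paired with the SD 1.5 base weights loaded below.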
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Load the fine-tuned LoRA adapter into the UNet
lora_weights_path = "rohith2812/atoi-lora-finetuned-v1"  # Hub repo (or local path) with the LoRA weights
pipe.unet.load_attn_procs(lora_weights_path)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
print("Pipeline is ready with fine-tuned LoRA adapter!")

# Whisper (English-only base model) for transcribing the spoken prompt
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
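# The ASR pipeline accepts raw audio as {"sampling_rate": sr, "raw": y} with y a mono
# float32 NumPy array, which is how the Gradio microphone input is converted below.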
def retrieve_image_from_text_prompt(prompt, selected_index, knowledge_database):
    """
    Retrieve the most relevant image for the text prompt from the selected Pinecone index.
    """
    dataset = load_dataset(knowledge_database, split="train")

    # Connect to the Pinecone index chosen by the user
    index = pc.Index(selected_index)

    # Generate the CLIP text embedding for the prompt
    text_tokens = tokenizer([prompt])
    with torch.no_grad():
        query_embedding = model.encode_text(text_tokens).cpu().numpy().flatten()

    # Query Pinecone for the closest match
    results = index.query(vector=query_embedding.tolist(), top_k=1, include_metadata=True, namespace="text_embeddings")
    if results and "matches" in results and results["matches"]:
        best_match = results["matches"][0]
        image_path = best_match["metadata"]["image_path"]
        description = best_match["metadata"]["description"]

        # Match the image path against the dataset to retrieve the actual image
        for item in dataset:
            if item["image_path"].endswith(image_path):
                return {"image": item["image"], "description": description}
    return None
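# Example usage (assuming the "project-atoi-v2" index and the dataset below exist and are populated):
#   match = retrieve_image_from_text_prompt(
#       "line of best fit in linear regression",
#       selected_index="project-atoi-v2",
#       knowledge_database="rxc5667/3wordsdataset_noduplicates",
#   )
#   if match:
#       match["image"].save("retrieved.png")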
# Function to generate a depth map used as the ControlNet conditioning image
def get_depth_map(image):
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)  # replicate depth to 3 channels
    detected_map = torch.from_numpy(image).float() / 255.0
    return detected_map.permute(2, 0, 1).unsqueeze(0).half().to("cuda")
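# The returned tensor has shape (1, 3, H, W) in float16 on the GPU, matching the
# control_image format expected by the ControlNet img2img pipeline.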
from transformers import CLIPProcessor, CLIPModel

# Load CLIP model and processor for scoring image-text alignment
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def calculate_clip_score(image, text):
    """Calculate a CLIP score for an image and text pair."""
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to("cuda")
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image is CLIP's scaled image-text similarity; with a single caption,
    # a softmax over it is always 1.0, so report the raw similarity instead.
    clip_score = outputs.logits_per_image.item()
    return clip_score
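# Higher CLIP scores indicate closer alignment between the enhanced image and the transcribed
# prompt; the value is shown as-is in the "CLIP Score" textbox of the UI.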
def audio_to_image(audio, guidance_scale, num_inference_steps, selected_index, knowledge_database):
    # Initialize the Pinecone index and knowledge dataset chosen by the user
    dataset = load_dataset(knowledge_database, split="train")
    index = pc.Index(selected_index)
    print(f"Connected to Pinecone index: {selected_index}")

    # Step 1: Transcribe audio
    sr, y = audio
    if y.ndim > 1:
        y = y.mean(axis=1)  # Convert to mono
    y = y.astype(np.float32)
    peak = np.abs(y).max()
    if peak > 0:
        y /= peak  # Normalize to [-1, 1]; skip for a silent clip to avoid division by zero
    transcription = transcriber({"sampling_rate": sr, "raw": y})["text"]
    print(f"Transcribed Text: {transcription}")

    # Step 2: Retrieve an image based on the transcribed prompt
    print("Retrieving image from vector database...")
    retrieved_data = retrieve_image_from_text_prompt(transcription, selected_index, knowledge_database)
    if not retrieved_data:
        return transcription, None, None, "No relevant image found.", None
    retrieved_image = retrieved_data["image"]
    retrieved_description = retrieved_data["description"]

    # Step 3: Generate the depth map
    print("Generating depth map...")
    depth_map = get_depth_map(retrieved_image)

    # Step 4: Enhance the image using Stable Diffusion
    print("Enhancing image with Stable Diffusion...")
    enhanced_image = pipe(
        prompt=f"{transcription}. Ensure formulas are accurate and text is clean and legible.",
        image=retrieved_image,
        control_image=depth_map,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps
    ).images[0]

    # Step 5: Calculate the CLIP score
    print("Calculating CLIP Score...")
    clip_score = calculate_clip_score(enhanced_image, transcription)

    # Return the retrieved and enhanced images along with the CLIP score
    return transcription, retrieved_image, enhanced_image, retrieved_description, clip_score
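# Higher guidance_scale values make the generation follow the transcribed prompt more closely,
# and more inference steps generally improve detail at the cost of latency; the UI defaults
# below are 8.5 and 100.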
# Gradio interface function: thin wrapper that forwards UI inputs to audio_to_image
def gradio_interface(audio, guidance_scale, num_inference_steps, selected_index, knowledge_database):
    transcription, retrieved_image, enhanced_image, retrieved_description, clip_score = audio_to_image(
        audio, guidance_scale, num_inference_steps, selected_index, knowledge_database
    )
    if enhanced_image is None:
        # Return None for the image components so Gradio leaves them empty
        return transcription, None, None, retrieved_description, "N/A"
    return transcription, retrieved_image, enhanced_image, retrieved_description, clip_score
# Enhanced Gradio UI
with gr.Blocks(title="Audio-to-Image Generation") as demo:
    gr.Markdown(
        """
        # 🎨 Audio-to-Image Generation with AI
        Speak into the microphone, and watch as this AI application retrieves a relevant image from the database,
        enhances it based on your input, and displays its description and CLIP Score.
        """
    )
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="numpy", label="🎤 Speak Your Prompt")
            guidance_scale_input = gr.Slider(
                minimum=1.0, maximum=20.0, step=0.5, value=8.5, label="🎚️ Guidance Scale"
            )
            num_inference_steps_input = gr.Slider(
                minimum=10, maximum=200, step=10, value=100, label="🔢 Number of Inference Steps"
            )
            index_selection = gr.Dropdown(
                choices=["project-atoi-v2", "project-atoi"],
                value="project-atoi-v2",
                label="🗂️ Select Pinecone Index"
            )
            knowledge_database_selection = gr.Dropdown(
                choices=["rohith2812/atoigeneration-final-data", "rxc5667/3wordsdataset_noduplicates"],
                value="rxc5667/3wordsdataset_noduplicates",
                label="📚 Select Knowledge Database"
            )
            submit_button = gr.Button("Generate Image")
        with gr.Column():
            transcription_output = gr.Textbox(label="📝 Transcribed Prompt")
            retrieved_image_output = gr.Image(label="🖼️ Retrieved Image")
            enhanced_image_output = gr.Image(label="✨ Enhanced Image")
            retrieved_description_output = gr.Textbox(label="📖 Retrieved Description")
            clip_score_output = gr.Textbox(label="📊 CLIP Score")
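    # The output components above correspond positionally to the tuple returned by
    # gradio_interface: transcription, retrieved image, enhanced image, description, CLIP score.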
    # Example prompts illustrating the kind of spoken input the app expects
    examples = gr.Examples(
        examples=[
            ["A picture explaining the line of best fit in linear regression"],
            ["Support vector machines"],
            ["A picture explaining multiple components in PCA"],
        ],
        inputs=[
            audio_input,
            guidance_scale_input,
            num_inference_steps_input,
            index_selection,
            knowledge_database_selection,
        ],
        outputs=[
            transcription_output,
            retrieved_image_output,
            enhanced_image_output,
            retrieved_description_output,
            clip_score_output,
        ],
        label="Examples",
    )
    # Wire the button to the inference function
    submit_button.click(
        fn=gradio_interface,
        inputs=[
            audio_input,
            guidance_scale_input,
            num_inference_steps_input,
            index_selection,
            knowledge_database_selection,
        ],
        outputs=[
            transcription_output,
            retrieved_image_output,
            enhanced_image_output,
            retrieved_description_output,
            clip_score_output,
        ],
    )

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch()