import os
import sys

import gradio as gr
import librosa
import numpy as np
import torch
from diffusers import StableDiffusionPipeline

# Debug info: confirm which Gradio build and Python interpreter are in use.
print(f"Gradio version: {gr.__version__}")
print(f"Gradio location: {gr.__file__}")
print(f"Python executable: {sys.executable}")
# Use CUDA if available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Stable Diffusion pipeline.
model_id = "runwayml/stable-diffusion-v1-5"
try:
    stable_diffusion = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
except Exception as e:
    print(f"Error loading the model: {e}")
    print("Ensure you have the correct model ID and access rights.")
    sys.exit(1)
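# Optional sketch (an assumption, not part of the original script): attention
# slicing trades a little speed for a lower peak GPU memory footprint, which
# can help on small GPUs. enable_attention_slicing() is a standard diffusers
# pipeline method; drop this block if memory is not a concern.
if device == "cuda":
    stable_diffusion.enable_attention_slicing()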
def describe_audio(audio_path):
    """
    Generate a textual description based on audio features.

    Parameters:
        audio_path (str): Path to the audio file.

    Returns:
        str: Generated description.
    """
    try:
        # Load the audio file at its native sampling rate.
        y, sr = librosa.load(audio_path, sr=None)

        # Extract a Mel spectrogram and convert it to decibels
        # (ref=np.max puts the loudest bin at 0 dB, so values are <= 0).
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
        db_spec = librosa.power_to_db(S, ref=np.max)

        # Average loudness (dB) and average spectral centroid (Hz).
        avg_amplitude = np.mean(db_spec)
        spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
        avg_frequency = np.mean(spectral_centroids)

        # Map average loudness to a scene description.
        if avg_amplitude < -40:
            amplitude_desc = "a calm and serene landscape with gentle waves"
        elif avg_amplitude < -20:
            amplitude_desc = "a vibrant forest with rustling leaves"
        else:
            amplitude_desc = "a thunderstorm with dark clouds and lightning"

        # Map average spectral centroid to a lighting/color description.
        if avg_frequency < 2000:
            frequency_desc = "under soft, ambient light"
        elif avg_frequency < 4000:
            frequency_desc = "with vivid and lively colors"
        else:
            frequency_desc = "in a surreal and dynamic setting"

        # Combine the two parts into a single prompt.
        description = f"{amplitude_desc} {frequency_desc}"
        return description
    except Exception as e:
        print(f"Error processing audio: {e}")
        return "an abstract artistic scene"
def generate_image(description):
    """
    Generate an image with the Stable Diffusion pipeline from a text description.

    Parameters:
        description (str): Textual description for image generation.

    Returns:
        PIL.Image.Image or None: Generated image, or None on failure.
    """
    try:
        if device == "cuda":
            # Run the pipeline in mixed precision on GPU.
            with torch.autocast("cuda"):
                image = stable_diffusion(description).images[0]
        else:
            image = stable_diffusion(description).images[0]
        return image
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
def audio_to_image(audio_file):
    """
    Convert an uploaded audio file into an artistic image.

    Parameters:
        audio_file (str): Path to the uploaded audio file.

    Returns:
        PIL.Image.Image or str: Generated image, or an error message.
    """
    if audio_file is None:
        return "No audio file provided."

    description = describe_audio(audio_file)
    print(f"Generated Description: {description}")

    image = generate_image(description)
    if image is not None:
        return image
    else:
        return "Failed to generate image."
# Gradio Interface
title = "🎵 Audio to Artistic Image Converter 🎨"
description_text = """
Upload an audio file, and this app will generate an artistic image based on the sound's characteristics.
"""

# Example audio clips shown in the interface.
example_paths = [
    "example_audio/calm_ocean.wav",
    "example_audio/rustling_leaves.wav",
    "example_audio/thunderstorm.wav",
]

# Keep only the example files that actually exist on disk.
valid_examples = []
for path in example_paths:
    if os.path.isfile(path):
        valid_examples.append([path])
    else:
        print(f"Example file not found: {path}")

if not os.path.exists("example_audio"):
    os.makedirs("example_audio")
    print("Please add some example audio files in the 'example_audio' directory.")
interface = gr.Interface(
    fn=audio_to_image,
    # Note: Gradio 4.x renamed this parameter to `sources=["upload"]`;
    # `source="upload"` is the Gradio 3.x spelling.
    inputs=gr.Audio(source="upload", type="filepath"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description_text,
    examples=valid_examples if valid_examples else None,
    allow_flagging="never",
    theme="default",
)

if __name__ == "__main__":
    interface.launch()