# Hugging Face Spaces page banner captured with this file (status: Running).
import hashlib
import io
import json
import multiprocessing
import os
import pickle
import time
import wave

import gradio as gr
import redis
import torch
import torchaudio
from audiocraft.models import musicgen
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from diffusers.utils import export_to_video
from google.cloud import storage
from huggingface_hub import HfApi, HfFolder, snapshot_download
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2Model,
    GPT2Tokenizer,
    pipeline as transformers_pipeline,
)
# --- Runtime configuration, read once from the environment at import time ---
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv("REDIS_PASSWORD")
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")

# GCS_CREDENTIALS holds the service-account key as a JSON document. Fail with
# an explicit message instead of the opaque TypeError json.loads(None) raises.
_gcs_credentials_json = os.getenv("GCS_CREDENTIALS")
if not _gcs_credentials_json:
    raise RuntimeError(
        "GCS_CREDENTIALS must be set to the service-account JSON document"
    )
gcs_credentials = json.loads(_gcs_credentials_json)

# Only persist a Hugging Face token when one was actually provided.
if hf_token:
    HfFolder.save_token(hf_token)

storage_client = storage.Client.from_service_account_info(gcs_credentials)
def connect_to_redis():
    """Return a Redis client that has answered a ping.

    The client is built from the module-level ``redis_host``, ``redis_port``
    and ``redis_password`` settings. A connection error, timeout or broken
    pipe is printed and followed by a one-second pause before the next
    attempt. There is no attempt limit, so the call blocks until a ping
    succeeds.
    """
    retryable_errors = (
        redis.exceptions.ConnectionError,
        redis.exceptions.TimeoutError,
        BrokenPipeError,
    )
    while True:
        try:
            client = redis.Redis(
                host=redis_host, port=redis_port, password=redis_password
            )
            client.ping()
        except retryable_errors as e:
            print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
            time.sleep(1)
        else:
            return client
def reconnect_if_needed(redis_client):
    """Return a usable Redis client, replacing ``redis_client`` if it is dead.

    The given client is pinged. If the ping succeeds the same object is
    returned. If it fails with a connection error, timeout or broken pipe, a
    notice is printed and a fresh client from ``connect_to_redis()`` is
    returned instead.
    """
    try:
        redis_client.ping()
    except (
        redis.exceptions.ConnectionError,
        redis.exceptions.TimeoutError,
        BrokenPipeError,
    ):
        print("Reconnecting to Redis...")
        replacement = connect_to_redis()
        return replacement
    else:
        return redis_client
def load_object_from_redis(key):
    """Fetch ``key`` from Redis and return the unpickled object.

    Returns ``None`` when the key is absent or empty, or when the lookup or
    the deserialization fails (the failure is printed).

    SECURITY: the payload is deserialized with ``pickle.loads``, which can
    execute arbitrary code. This is only safe while every writer to this
    Redis instance is trusted.
    """
    # connect_to_redis() only returns after a successful ping, so no second
    # health check is needed before using the client.
    redis_client = connect_to_redis()
    try:
        payload = redis_client.get(key)
        if not payload:
            return None
        return pickle.loads(payload)
    except (
        redis.exceptions.RedisError,
        pickle.PickleError,
        # pickle.loads may also raise these for truncated or stale payloads.
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
    ) as e:
        print(f"Failed to load object from Redis: {e}")
        return None
    finally:
        # A new client is created per call; release its connections.
        redis_client.close()
def save_object_to_redis(key, obj):
    """Pickle ``obj`` and store it in Redis under ``key`` (best effort).

    Nothing is returned. Serialization failures and Redis errors are printed
    rather than raised, so a failed cache write never aborts the caller.
    """
    # Serialize first: objects that cannot be pickled should not cost a
    # connection, and their error must not escape a best-effort cache write.
    try:
        payload = pickle.dumps(obj)
    except (pickle.PicklingError, TypeError, AttributeError) as e:
        print(f"Failed to save object to Redis: {e}")
        return

    redis_client = connect_to_redis()
    try:
        redis_client.set(key, payload)
    except redis.exceptions.RedisError as e:
        print(f"Failed to save object to Redis: {e}")
    finally:
        # A new client is created per call; release its connections.
        redis_client.close()
def upload_to_gcs(bucket_name, blob_name, data):
    """Upload ``data`` as the object ``blob_name`` in bucket ``bucket_name``.

    Uses the module-level ``storage_client``; nothing is returned.
    """
    target = storage_client.bucket(bucket_name).blob(blob_name)
    target.upload_from_string(data)
def download_from_gcs(bucket_name, blob_name):
    """Return the bytes of object ``blob_name`` in bucket ``bucket_name``.

    Uses the module-level ``storage_client``.
    """
    source = storage_client.bucket(bucket_name).blob(blob_name)
    return source.download_as_bytes()
def get_model_or_download(model_id, redis_key, loader_func):
    """Return the model cached under ``redis_key``, loading it on a miss.

    Args:
        model_id: Identifier passed as the first argument to ``loader_func``.
        redis_key: Redis key (also used as the GCS object name) for the cache.
        loader_func: Callable invoked as
            ``loader_func(model_id, torch_dtype=torch.float16)``.

    Returns:
        The cached or freshly loaded model, or ``None`` when loading fails.
    """
    model = load_object_from_redis(redis_key)
    # Compare against None explicitly: the truthiness of an arbitrary model
    # object is not a reliable "cache hit" signal.
    if model is not None:
        return model

    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None

    # Caching is best effort. A model that loaded successfully is still
    # returned when it cannot be pickled or uploaded.
    try:
        save_object_to_redis(redis_key, model)
        upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(model))
    except Exception as e:
        print(f"Failed to cache model {model_id}: {e}")
    return model
def generate_image(prompt):
    """Return JPEG bytes for ``prompt``, using the Redis cache when possible.

    On a cache miss the image is produced by ``text_to_image_pipeline``,
    encoded as JPEG, stored in Redis and uploaded to GCS under the same key.
    Any exception along that path is printed and ``None`` is returned.
    """
    redis_key = f"generated_image:{prompt}"
    cached_bytes = load_object_from_redis(redis_key)
    if cached_bytes:
        return cached_bytes

    try:
        with tqdm(total=1, desc="Generating image") as progress:
            result = text_to_image_pipeline(prompt)
            image = result.images[0]
            progress.update(1)
        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format="JPEG")
        image_bytes = jpeg_buffer.getvalue()
        save_object_to_redis(redis_key, image_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
    except Exception as e:
        print(f"Failed to generate image: {e}")
        return None
    return image_bytes
def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
    """Edit an image with ``img2img_pipeline`` and return JPEG bytes.

    Args:
        image_bytes: Encoded image bytes, or an already decoded PIL image.
        prompt: Text prompt guiding the edit.
        strength: Value forwarded to the pipeline's ``strength`` argument.

    Returns:
        JPEG-encoded bytes of the edited image (possibly from the Redis
        cache), or ``None`` if decoding, editing or persisting fails.
    """
    try:
        if isinstance(image_bytes, Image.Image):
            image = image_bytes.convert("RGB")
        else:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        print(f"Failed to edit image: {e}")
        return None

    # The cache key must identify the input image as well as the prompt;
    # otherwise two different images with the same prompt/strength would
    # share one cached result. sha256 is stable across processes.
    image_digest = hashlib.sha256(
        repr(image.size).encode() + image.tobytes()
    ).hexdigest()
    redis_key = f"edited_image:{image_digest}:{prompt}:{strength}"

    edited_image_bytes = load_object_from_redis(redis_key)
    if edited_image_bytes:
        return edited_image_bytes

    try:
        with tqdm(total=1, desc="Editing image") as pbar:
            edited_image = img2img_pipeline(
                prompt=prompt, image=image, strength=strength
            ).images[0]
            pbar.update(1)
        buffered = io.BytesIO()
        edited_image.save(buffered, format="JPEG")
        edited_image_bytes = buffered.getvalue()
        save_object_to_redis(redis_key, edited_image_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
    except Exception as e:
        print(f"Failed to edit image: {e}")
        return None
    return edited_image_bytes
def generate_song(prompt, duration=10):
    """Generate audio for ``prompt`` with ``music_gen`` and return WAV bytes.

    Args:
        prompt: Text description of the music.
        duration: Requested length, passed to ``set_generation_params``.

    Returns:
        16-bit PCM WAV bytes (possibly from the Redis cache), or ``None`` if
        generation or persisting fails.
    """
    redis_key = f"generated_song:{prompt}:{duration}"
    song_bytes = load_object_from_redis(redis_key)
    if song_bytes:
        return song_bytes

    try:
        with tqdm(total=1, desc="Generating song") as pbar:
            # MusicGen objects are not callable: the duration is set through
            # set_generation_params and generate() takes a list of prompts.
            music_gen.set_generation_params(duration=duration)
            wav = music_gen.generate([prompt])
            pbar.update(1)

        # generate() returns a float tensor, not a buffer; encode the first
        # item (channels x samples) as interleaved 16-bit PCM.
        pcm = (wav[0].detach().cpu().float().clamp(-1.0, 1.0) * 32767.0).to(
            torch.int16
        )
        buffered = io.BytesIO()
        with wave.open(buffered, "wb") as wav_file:
            wav_file.setnchannels(pcm.shape[0])
            wav_file.setsampwidth(2)
            wav_file.setframerate(music_gen.sample_rate)
            wav_file.writeframes(pcm.t().contiguous().numpy().tobytes())
        song_bytes = buffered.getvalue()

        save_object_to_redis(redis_key, song_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
    except Exception as e:
        print(f"Failed to generate song: {e}")
        return None
    return song_bytes
def generate_text(prompt):
    """Return generated text for ``prompt``, using the Redis cache if possible.

    On a cache miss ``text_gen_pipeline`` is run with ``max_new_tokens=256``.
    The stripped ``generated_text`` of the first result is stored in Redis
    and uploaded (UTF-8 encoded) to GCS. Any exception is printed and
    ``None`` is returned.
    """
    redis_key = f"generated_text:{prompt}"
    cached_text = load_object_from_redis(redis_key)
    if cached_text:
        return cached_text

    try:
        with tqdm(total=1, desc="Generating text") as progress:
            candidates = text_gen_pipeline(prompt, max_new_tokens=256)
            text = candidates[0]["generated_text"].strip()
            progress.update(1)
        save_object_to_redis(redis_key, text)
        upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
    except Exception as e:
        print(f"Failed to generate text: {e}")
        return None
    return text
def generate_flux_image(prompt):
    """Return JPEG bytes for ``prompt`` generated with ``flux_pipeline``.

    The result is cached in Redis and mirrored to GCS under
    ``generated_flux_image:<prompt>``. The generator is seeded with 0 on the
    CPU. Any exception is printed and ``None`` is returned.
    """
    redis_key = f"generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_redis(redis_key)
    if flux_image_bytes:
        return flux_image_bytes

    try:
        with tqdm(total=1, desc="Generating FLUX image") as pbar:
            flux_image = flux_pipeline(
                prompt,
                guidance_scale=0.0,
                num_inference_steps=4,
                # FluxPipeline names this argument max_sequence_length;
                # "max_length" is rejected as an unexpected keyword.
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(0),
            ).images[0]
            pbar.update(1)
        buffered = io.BytesIO()
        flux_image.save(buffered, format="JPEG")
        flux_image_bytes = buffered.getvalue()
        save_object_to_redis(redis_key, flux_image_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
    except Exception as e:
        print(f"Failed to generate flux image: {e}")
        return None
    return flux_image_bytes
def generate_code(prompt):
    """Return a StarCoder completion for ``prompt``, cached in Redis.

    On a cache miss the prompt is tokenized, moved to the model's device and
    passed to ``starcoder_model.generate`` with ``max_new_tokens=256``. The
    decoded first sequence is stored in Redis and uploaded to GCS. Any
    exception is printed and ``None`` is returned.
    """
    redis_key = f"generated_code:{prompt}"
    cached_code = load_object_from_redis(redis_key)
    if cached_code:
        return cached_code

    try:
        with tqdm(total=1, desc="Generating code") as progress:
            token_ids = starcoder_tokenizer.encode(prompt, return_tensors="pt")
            token_ids = token_ids.to(starcoder_model.device)
            generated = starcoder_model.generate(token_ids, max_new_tokens=256)
            code = starcoder_tokenizer.decode(generated[0])
            progress.update(1)
        save_object_to_redis(redis_key, code)
        upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
    except Exception as e:
        print(f"Failed to generate code: {e}")
        return None
    return code
def test_model_meta_llama():
    """Run a fixed pirate-chat smoke test against ``meta_llama_pipeline``.

    Returns:
        The assistant's reply text (possibly from the Redis cache under
        ``meta_llama_test_response``), or ``None`` if the run fails.
    """
    redis_key = "meta_llama_test_response"
    response = load_object_from_redis(redis_key)
    if response:
        return response

    try:
        messages = [
            {
                "role": "system",
                "content": "You are a pirate chatbot who always responds in pirate speak!",
            },
            {"role": "user", "content": "Who are you?"},
        ]
        with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
            generated = meta_llama_pipeline(messages, max_new_tokens=256)[0][
                "generated_text"
            ]
            pbar.update(1)
        # For chat-style input the pipeline returns the whole conversation as
        # a list of message dicts; the reply is the last message's content.
        if isinstance(generated, list):
            generated = generated[-1]["content"]
        response = generated.strip()
        save_object_to_redis(redis_key, response)
        upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
    except Exception as e:
        print(f"Failed to test Meta-Llama: {e}")
        return None
    return response
def generate_image_sdxl(prompt):
    """Return JPEG bytes for ``prompt`` from the SDXL base + refiner pair.

    ``base`` runs 40 steps up to ``denoising_end=0.8`` and emits latents.
    ``refiner`` continues from ``denoising_start=0.8`` on those latents. The
    result is cached in Redis and uploaded to GCS. Any exception is printed
    and ``None`` is returned.
    """
    redis_key = f"generated_image_sdxl:{prompt}"
    cached_bytes = load_object_from_redis(redis_key)
    if cached_bytes:
        return cached_bytes

    try:
        with tqdm(total=1, desc="Generating SDXL image") as progress:
            latents = base(
                prompt=prompt,
                num_inference_steps=40,
                denoising_end=0.8,
                output_type="latent",
            ).images
            refined = refiner(
                prompt=prompt,
                num_inference_steps=40,
                denoising_start=0.8,
                image=latents,
            )
            image = refined.images[0]
            progress.update(1)
        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format="JPEG")
        image_bytes = jpeg_buffer.getvalue()
        save_object_to_redis(redis_key, image_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
    except Exception as e:
        print(f"Failed to generate SDXL image: {e}")
        return None
    return image_bytes
def generate_musicgen_melody(prompt):
    """Generate melody-conditioned audio for ``prompt`` and return WAV bytes.

    The conditioning melody is read from ``./assets/bach.mp3`` and passed,
    together with the prompt, to ``music_gen_melody.generate_with_chroma``.

    Returns:
        16-bit PCM WAV bytes (possibly from the Redis cache), or ``None`` if
        generation or persisting fails.
    """
    redis_key = f"generated_musicgen_melody:{prompt}"
    song_bytes = load_object_from_redis(redis_key)
    if song_bytes:
        return song_bytes

    try:
        with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
            melody, sr = torchaudio.load("./assets/bach.mp3")
            # One description needs exactly one melody: add a batch axis of
            # size 1 instead of expanding the melody to a batch of 3.
            wav = music_gen_melody.generate_with_chroma(
                [prompt], melody[None], sr
            )
            pbar.update(1)

        # The model returns a float tensor, not a buffer; encode the first
        # item (channels x samples) as interleaved 16-bit PCM.
        pcm = (wav[0].detach().cpu().float().clamp(-1.0, 1.0) * 32767.0).to(
            torch.int16
        )
        buffered = io.BytesIO()
        with wave.open(buffered, "wb") as wav_file:
            wav_file.setnchannels(pcm.shape[0])
            wav_file.setsampwidth(2)
            wav_file.setframerate(music_gen_melody.sample_rate)
            wav_file.writeframes(pcm.t().contiguous().numpy().tobytes())
        song_bytes = buffered.getvalue()

        save_object_to_redis(redis_key, song_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
    except Exception as e:
        print(f"Failed to generate MusicGen melody: {e}")
        return None
    return song_bytes
def generate_musicgen_large(prompt):
    """Generate audio for ``prompt`` with ``music_gen_large`` as WAV bytes.

    Returns:
        16-bit PCM WAV bytes (possibly from the Redis cache), or ``None`` if
        generation or persisting fails.
    """
    redis_key = f"generated_musicgen_large:{prompt}"
    song_bytes = load_object_from_redis(redis_key)
    if song_bytes:
        return song_bytes

    try:
        with tqdm(total=1, desc="Generating MusicGen large") as pbar:
            wav = music_gen_large.generate([prompt])
            pbar.update(1)

        # generate() returns a float tensor, not a buffer; encode the first
        # item (channels x samples) as interleaved 16-bit PCM.
        pcm = (wav[0].detach().cpu().float().clamp(-1.0, 1.0) * 32767.0).to(
            torch.int16
        )
        buffered = io.BytesIO()
        with wave.open(buffered, "wb") as wav_file:
            wav_file.setnchannels(pcm.shape[0])
            wav_file.setsampwidth(2)
            wav_file.setframerate(music_gen_large.sample_rate)
            wav_file.writeframes(pcm.t().contiguous().numpy().tobytes())
        song_bytes = buffered.getvalue()

        save_object_to_redis(redis_key, song_bytes)
        upload_to_gcs(gcs_bucket_name, redis_key, song_bytes)
    except Exception as e:
        print(f"Failed to generate MusicGen large: {e}")
        return None
    return song_bytes
def transcribe_audio(audio_sample):
    """Transcribe audio with ``whisper_pipeline`` and cache the text.

    Args:
        audio_sample: Either a NumPy array of samples, or a
            ``(sampling_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``. ``None`` yields ``None``.

    Returns:
        The transcribed text (possibly from the Redis cache), or ``None`` if
        no audio was given or transcription/persisting fails.
    """
    if audio_sample is None:
        return None

    if isinstance(audio_sample, tuple):
        sampling_rate, samples = audio_sample
    else:
        sampling_rate, samples = None, audio_sample

    # hash() of bytes is salted per process, so it cannot key a cache that
    # outlives the process; use a content digest instead.
    digest = hashlib.sha256(
        f"{sampling_rate}:{samples.dtype}:".encode() + samples.tobytes()
    ).hexdigest()
    redis_key = f"transcribed_audio:{digest}"

    text = load_object_from_redis(redis_key)
    if text:
        return text

    try:
        if sampling_rate is None:
            pipeline_input = samples.copy()
        else:
            # Hand the pipeline mono float32 samples plus their rate so it
            # can resample to what the model expects.
            if samples.dtype.kind == "i":
                full_scale = float(2 ** (8 * samples.dtype.itemsize - 1))
                waveform = samples.astype("float32") / full_scale
            else:
                waveform = samples.astype("float32")
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)
            pipeline_input = {"sampling_rate": sampling_rate, "raw": waveform}

        with tqdm(total=1, desc="Transcribing audio") as pbar:
            text = whisper_pipeline(pipeline_input, batch_size=8)["text"]
            pbar.update(1)
        save_object_to_redis(redis_key, text)
        upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
    except Exception as e:
        print(f"Failed to transcribe audio: {e}")
        return None
    return text
def generate_mistral_instruct(prompt):
    """Return the Mistral Instruct model's reply to ``prompt``, cached in Redis.

    On a cache miss the prompt is wrapped as a single user message, rendered
    with the tokenizer's chat template (with the module-level ``tools``),
    and passed to ``mistral_instruct_model.generate`` with
    ``max_new_tokens=1000``. The first output sequence is decoded without
    special tokens, stored in Redis and uploaded to GCS. Any exception is
    printed and ``None`` is returned.
    """
    redis_key = f"generated_mistral_instruct:{prompt}"
    cached_response = load_object_from_redis(redis_key)
    if cached_response:
        return cached_response

    try:
        conversation = [{"role": "user", "content": prompt}]
        with tqdm(total=1, desc="Generating Mistral Instruct response") as progress:
            encoded = mistral_instruct_tokenizer.apply_chat_template(
                conversation,
                tools=tools,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            encoded.to(mistral_instruct_model.device)
            generated = mistral_instruct_model.generate(
                **encoded, max_new_tokens=1000
            )
            first_sequence = generated[0]
            response = mistral_instruct_tokenizer.decode(
                first_sequence, skip_special_tokens=True
            )
            progress.update(1)
        save_object_to_redis(redis_key, response)
        upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
    except Exception as e:
        print(f"Failed to generate Mistral Instruct response: {e}")
        return None
    return response
def generate_mistral_nemo(prompt):
    """Return the Mistral Nemo model's reply to ``prompt``, cached in Redis.

    On a cache miss the prompt is wrapped as a single user message, rendered
    with the tokenizer's chat template (with the module-level ``tools``),
    and passed to ``mistral_nemo_model.generate`` with
    ``max_new_tokens=1000``. The first output sequence is decoded without
    special tokens, stored in Redis and uploaded to GCS. Any exception is
    printed and ``None`` is returned.
    """
    redis_key = f"generated_mistral_nemo:{prompt}"
    cached_response = load_object_from_redis(redis_key)
    if cached_response:
        return cached_response

    try:
        conversation = [{"role": "user", "content": prompt}]
        with tqdm(total=1, desc="Generating Mistral Nemo response") as progress:
            encoded = mistral_nemo_tokenizer.apply_chat_template(
                conversation,
                tools=tools,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            encoded.to(mistral_nemo_model.device)
            generated = mistral_nemo_model.generate(**encoded, max_new_tokens=1000)
            first_sequence = generated[0]
            response = mistral_nemo_tokenizer.decode(
                first_sequence, skip_special_tokens=True
            )
            progress.update(1)
        save_object_to_redis(redis_key, response)
        upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
    except Exception as e:
        print(f"Failed to generate Mistral Nemo response: {e}")
        return None
    return response
def generate_gpt2_xl(prompt):
    """Return a GPT-2 XL response for ``prompt``, cached in Redis.

    On a cache miss the prompt is tokenized, run through ``gpt2_xl_model``
    and the first row of the first output is decoded without special
    tokens. The text is stored in Redis and uploaded to GCS. Any exception
    is printed and ``None`` is returned.
    """
    # NOTE(review): gpt2_xl_model is loaded in this file as a bare GPT2Model
    # and is called directly rather than through generate(); its first
    # output is presumably hidden states, not token ids, so decode() likely
    # fails and this returns None. Confirm, and consider a model with an
    # LM head plus generate().
    redis_key = f"generated_gpt2_xl:{prompt}"
    cached_response = load_object_from_redis(redis_key)
    if cached_response:
        return cached_response

    try:
        with tqdm(total=1, desc="Generating GPT-2 XL response") as progress:
            encoded = gpt2_xl_tokenizer(prompt, return_tensors="pt")
            model_output = gpt2_xl_model(**encoded)
            first_row = model_output[0][0]
            response = gpt2_xl_tokenizer.decode(
                first_row, skip_special_tokens=True
            )
            progress.update(1)
        save_object_to_redis(redis_key, response)
        upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
    except Exception as e:
        print(f"Failed to generate GPT-2 XL response: {e}")
        return None
    return response
def answer_question_minicpm(image_bytes, question):
    """Answer ``question`` about an image with ``minicpm_model``.

    Args:
        image_bytes: Encoded image bytes, or an already decoded PIL image.
        question: Question text sent alongside the image.

    Returns:
        The model's answer (possibly from the Redis cache), or ``None`` if
        decoding, answering or persisting fails.
    """
    try:
        if isinstance(image_bytes, Image.Image):
            image = image_bytes.convert("RGB")
        else:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        print(f"Failed to answer question with MiniCPM: {e}")
        return None

    # hash() of bytes is salted per process, so it cannot key a cache that
    # outlives the process; use a content digest of the decoded pixels.
    image_digest = hashlib.sha256(
        repr(image.size).encode() + image.tobytes()
    ).hexdigest()
    redis_key = f"minicpm_answer:{image_digest}:{question}"

    answer = load_object_from_redis(redis_key)
    if answer:
        return answer

    try:
        with tqdm(total=1, desc="Answering question with MiniCPM") as pbar:
            msgs = [{"role": "user", "content": [image, question]}]
            answer = minicpm_model.chat(
                image=None, msgs=msgs, tokenizer=minicpm_tokenizer
            )
            pbar.update(1)
        save_object_to_redis(redis_key, answer)
        upload_to_gcs(gcs_bucket_name, redis_key, answer.encode())
    except Exception as e:
        print(f"Failed to answer question with MiniCPM: {e}")
        return None
    return answer
# --- Model loading (runs once at import time) ---
# Device used for every model that is placed explicitly below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Diffusers pipelines: taken from the Redis cache when present, otherwise
# loaded with torch_dtype=float16 by get_model_or_download.
# NOTE(review): these three pipelines are never moved to `device` here;
# confirm whether a .to(device) is intended.
text_to_image_pipeline = get_model_or_download(
    "stabilityai/stable-diffusion-2",
    "text_to_image_model",
    StableDiffusionPipeline.from_pretrained,
)
img2img_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4",
    "img2img_model",
    StableDiffusionImg2ImgPipeline.from_pretrained,
)
flux_pipeline = get_model_or_download(
    "black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained
)

text_gen_pipeline = transformers_pipeline(
    "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
)

# MusicGen has no .to(); the target device is chosen via get_pretrained().
music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained(
    "melody", device=str(device)
)

# transformers_pipeline's first positional argument is the task, so it cannot
# be handed to get_model_or_download directly (the model id would be taken as
# the task). Wrap it so the id is passed as `model`.
meta_llama_pipeline = get_model_or_download(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta_llama_model",
    lambda model_id, **kwargs: transformers_pipeline(
        "text-generation", model=model_id, **kwargs
    ),
)

starcoder_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder"
).to(device)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")

# SDXL base + refiner; the refiner reuses the base's second text encoder
# and VAE.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to(device)

music_gen_melody = musicgen.MusicGen.get_pretrained("melody", device=str(device))
music_gen_melody.set_generation_params(duration=8)
music_gen_large = musicgen.MusicGen.get_pretrained("large", device=str(device))
music_gen_large.set_generation_params(duration=8)

whisper_pipeline = transformers_pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
    device=device,
)

mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Large-Instruct-2407"
)
mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407"
)

gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
gpt2_xl_model = GPT2Model.from_pretrained("gpt2-xl")

# Follow the CPU fallback chosen above instead of calling .cuda()
# unconditionally, which fails on machines without CUDA.
minicpm_model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
).eval().to(device)
minicpm_tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM-V-2_6", trust_remote_code=True
)

tools = []  # Define any tools needed for Mistral models
# --- Gradio tabs ---
# The image functions in this file take and return *encoded bytes*, while the
# gr.Image(type="pil") components hand over and expect PIL images. The small
# adapters below convert at the UI boundary so the functions keep their
# bytes-based interface.


def _as_pil_image(image_bytes):
    """Decode encoded image bytes into a PIL image; falsy input gives None."""
    return Image.open(io.BytesIO(image_bytes)) if image_bytes else None


def _as_png_bytes(image):
    """Encode a PIL image losslessly as PNG bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _returns_pil(generate):
    """Wrap a prompt -> image-bytes function so it returns a PIL image."""

    def run(prompt):
        return _as_pil_image(generate(prompt))

    return run


def _edit_image_for_ui(image, prompt, strength):
    """Adapter for the Edit Image tab: PIL in, PIL out."""
    if image is None:
        return None
    return _as_pil_image(
        edit_image_with_prompt(_as_png_bytes(image), prompt, strength)
    )


def _answer_question_for_ui(image, question):
    """Adapter for the MiniCPM tab: hands the PIL image over as bytes."""
    if image is None:
        return None
    return answer_question_minicpm(_as_png_bytes(image), question)


gen_image_tab = gr.Interface(
    fn=_returns_pil(generate_image),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate Image",
)
edit_image_tab = gr.Interface(
    fn=_edit_image_for_ui,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Prompt:"),
        gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
    ],
    outputs=gr.Image(type="pil"),
    title="Edit Image",
)
generate_song_tab = gr.Interface(
    fn=generate_song,
    inputs=[
        gr.Textbox(label="Prompt:"),
        gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
    ],
    outputs=gr.Audio(type="numpy"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Generate Text",
)
generate_flux_image_tab = gr.Interface(
    fn=_returns_pil(generate_flux_image),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate FLUX Images",
)
generate_code_tab = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Generated Code:"),
    title="Generate Code",
)
model_meta_llama_test_tab = gr.Interface(
    fn=test_model_meta_llama,
    inputs=None,
    outputs=gr.Textbox(label="Model Output:"),
    title="Test Meta-Llama",
)
generate_image_sdxl_tab = gr.Interface(
    fn=_returns_pil(generate_image_sdxl),
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Image(type="pil"),
    title="Generate SDXL Image",
)
generate_musicgen_melody_tab = gr.Interface(
    fn=generate_musicgen_melody,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Melody",
)
generate_musicgen_large_tab = gr.Interface(
    fn=generate_musicgen_large,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Audio(type="numpy"),
    title="Generate MusicGen Large",
)
transcribe_audio_tab = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Audio Sample:"),
    outputs=gr.Textbox(label="Transcribed Text:"),
    title="Transcribe Audio",
)
generate_mistral_instruct_tab = gr.Interface(
    fn=generate_mistral_instruct,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Instruct Response:"),
    title="Generate Mistral Instruct Response",
)
generate_mistral_nemo_tab = gr.Interface(
    fn=generate_mistral_nemo,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="Mistral Nemo Response:"),
    title="Generate Mistral Nemo Response",
)
generate_gpt2_xl_tab = gr.Interface(
    fn=generate_gpt2_xl,
    inputs=gr.Textbox(label="Prompt:"),
    outputs=gr.Textbox(label="GPT-2 XL Response:"),
    title="Generate GPT-2 XL Response",
)
answer_question_minicpm_tab = gr.Interface(
    fn=_answer_question_for_ui,
    inputs=[
        gr.Image(type="pil", label="Image:"),
        gr.Textbox(label="Question:"),
    ],
    outputs=gr.Textbox(label="MiniCPM Answer:"),
    title="Answer Question with MiniCPM",
)
# Tab label paired with the interface it shows, in display order.
_tabs = [
    ("Generate Image", gen_image_tab),
    ("Edit Image", edit_image_tab),
    ("Generate Song", generate_song_tab),
    ("Generate Text", generate_text_tab),
    ("Generate FLUX Image", generate_flux_image_tab),
    ("Generate Code", generate_code_tab),
    ("Test Meta-Llama", model_meta_llama_test_tab),
    ("Generate SDXL Image", generate_image_sdxl_tab),
    ("Generate MusicGen Melody", generate_musicgen_melody_tab),
    ("Generate MusicGen Large", generate_musicgen_large_tab),
    ("Transcribe Audio", transcribe_audio_tab),
    ("Generate Mistral Instruct Response", generate_mistral_instruct_tab),
    ("Generate Mistral Nemo Response", generate_mistral_nemo_tab),
    ("Generate GPT-2 XL Response", generate_gpt2_xl_tab),
    ("Answer Question with MiniCPM", answer_question_minicpm_tab),
]

# Combine every tab into one app and start serving it with a share link.
app = gr.TabbedInterface(
    [interface for _, interface in _tabs],
    [label for label, _ in _tabs],
)
app.launch(share=True)