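# Multimodal Hugging Face Space: text, image, music, code, and video generation
# behind a Gradio UI. Generated artifacts are cached in Redis and mirrored to
# Google Cloud Storage so repeated prompts are served without re-running models.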
import os
import io
import json
import time
import pickle
import multiprocessing

import torch
import redis
import gradio as gr
from PIL import Image
from tqdm import tqdm
from google.cloud import storage
from huggingface_hub import HfFolder
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.utils import export_to_video
from transformers import (
    pipeline as transformers_pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)
from audiocraft.models import musicgen
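# Configuration: all secrets and endpoints come from the Space's environment variables.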
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv("REDIS_PASSWORD")
gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")

HfFolder.save_token(hf_token)
storage_client = storage.Client.from_service_account_info(gcs_credentials)
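# Redis helpers: block and retry until the cache is reachable, and reconnect
# transparently when a connection drops mid-session.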
def connect_to_redis():
    while True:
        try:
            redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
            redis_client.ping()
            return redis_client
        except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
            print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
            time.sleep(1)

def reconnect_if_needed(redis_client):
    try:
        redis_client.ping()
    except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError):
        print("Reconnecting to Redis...")
        return connect_to_redis()
    return redis_client
def load_object_from_redis(key):
    # connect_to_redis() already verifies the connection with a ping.
    redis_client = connect_to_redis()
    try:
        obj_data = redis_client.get(key)
        return pickle.loads(obj_data) if obj_data else None
    except (pickle.PickleError, redis.exceptions.RedisError) as e:
        print(f"Failed to load object from Redis: {e}")
        return None

def save_object_to_redis(key, obj):
    redis_client = connect_to_redis()
    try:
        redis_client.set(key, pickle.dumps(obj))
    except redis.exceptions.RedisError as e:
        print(f"Failed to save object to Redis: {e}")
def upload_to_gcs(bucket_name, blob_name, data):
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.upload_from_string(data)

def download_from_gcs(bucket_name, blob_name):
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    return blob.download_as_bytes()
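# Model cache: try Redis first; on a miss, download from the Hugging Face Hub,
# then persist the pickled pipeline to Redis and GCS for the next cold start.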
def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if model:
        return model
    try:
        with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
            model = loader_func(model_id, torch_dtype=torch.float16)
            pbar.update(1)
        save_object_to_redis(redis_key, model)
        upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(model))
        return model  # hand back the freshly loaded pipeline, not just the cached one
    except Exception as e:
        print(f"Failed to load or save model: {e}")
        return None
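# Each generator below follows the same pattern: look up a prompt-derived key in
# Redis, otherwise run the model, then cache the result in Redis and mirror it to GCS.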
def generate_image(prompt):
    redis_key = f"generated_image:{prompt}"
    image_bytes = load_object_from_redis(redis_key)
    if not image_bytes:
        try:
            with tqdm(total=1, desc="Generating image") as pbar:
                image = text_to_image_pipeline(prompt).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, image_bytes)
        except Exception as e:
            print(f"Failed to generate image: {e}")
            return None
    # The Gradio output is gr.Image(type="pil"), so decode the cached JPEG bytes.
    return Image.open(io.BytesIO(image_bytes))
def edit_image_with_prompt(image, prompt, strength=0.75):
    # Note: the cache key ignores the input image, so identical prompt/strength
    # pairs reuse the first edit regardless of which image was uploaded.
    redis_key = f"edited_image:{prompt}:{strength}"
    edited_image_bytes = load_object_from_redis(redis_key)
    if not edited_image_bytes:
        try:
            with tqdm(total=1, desc="Editing image") as pbar:
                # gr.Image(type="pil") already hands us a PIL image; current
                # diffusers img2img pipelines take it via the `image=` argument.
                edited_image = img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            edited_image.save(buffered, format="JPEG")
            edited_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, edited_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, edited_image_bytes)
        except Exception as e:
            print(f"Failed to edit image: {e}")
            return None
    return Image.open(io.BytesIO(edited_image_bytes))
def generate_song(prompt, duration=10):
    redis_key = f"generated_song:{prompt}:{duration}"
    song = load_object_from_redis(redis_key)
    if song is None:
        try:
            with tqdm(total=1, desc="Generating song") as pbar:
                # MusicGen takes the duration via set_generation_params, not generate().
                music_gen.set_generation_params(duration=duration)
                wav = music_gen.generate([prompt])  # tensor of shape (batch, channels, samples)
                pbar.update(1)
            # gr.Audio(type="numpy") expects a (sample_rate, samples-first array) tuple.
            song = (music_gen.sample_rate, wav[0].cpu().numpy().T)
            save_object_to_redis(redis_key, song)
            upload_to_gcs(gcs_bucket_name, redis_key, pickle.dumps(song))
        except Exception as e:
            print(f"Failed to generate song: {e}")
            return None
    return song
def generate_text(prompt):
    redis_key = f"generated_text:{prompt}"
    text = load_object_from_redis(redis_key)
    if not text:
        try:
            with tqdm(total=1, desc="Generating text") as pbar:
                text = text_gen_pipeline(prompt, max_new_tokens=256)[0]["generated_text"].strip()
                pbar.update(1)
            save_object_to_redis(redis_key, text)
            upload_to_gcs(gcs_bucket_name, redis_key, text.encode())
        except Exception as e:
            print(f"Failed to generate text: {e}")
            return None
    return text
def generate_flux_image(prompt):
    redis_key = f"generated_flux_image:{prompt}"
    flux_image_bytes = load_object_from_redis(redis_key)
    if not flux_image_bytes:
        try:
            with tqdm(total=1, desc="Generating FLUX image") as pbar:
                # FLUX.1-schnell is distilled for few-step, guidance-free sampling.
                flux_image = flux_pipeline(
                    prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                pbar.update(1)
            buffered = io.BytesIO()
            flux_image.save(buffered, format="JPEG")
            flux_image_bytes = buffered.getvalue()
            save_object_to_redis(redis_key, flux_image_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, flux_image_bytes)
        except Exception as e:
            print(f"Failed to generate flux image: {e}")
            return None
    # Decode to PIL for the gr.Image(type="pil") output.
    return Image.open(io.BytesIO(flux_image_bytes))
def generate_code(prompt):
    redis_key = f"generated_code:{prompt}"
    code = load_object_from_redis(redis_key)
    if not code:
        try:
            with tqdm(total=1, desc="Generating code") as pbar:
                inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
                # Cap the generation length; .generate() defaults are very short otherwise.
                outputs = starcoder_model.generate(inputs, max_new_tokens=256)
                code = starcoder_tokenizer.decode(outputs[0], skip_special_tokens=True)
                pbar.update(1)
            save_object_to_redis(redis_key, code)
            upload_to_gcs(gcs_bucket_name, redis_key, code.encode())
        except Exception as e:
            print(f"Failed to generate code: {e}")
            return None
    return code
def generate_video(prompt):
    redis_key = f"generated_video:{prompt}"
    video_bytes = load_object_from_redis(redis_key)
    if not video_bytes:
        try:
            with tqdm(total=1, desc="Generating video") as pbar:
                pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
                pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
                pipe.enable_model_cpu_offload()
                # On recent diffusers versions, .frames[0] holds the frame list for the first prompt.
                frames = pipe(prompt, num_inference_steps=25).frames[0]
                video_path = export_to_video(frames)  # writes a temporary .mp4 and returns its path
                pbar.update(1)
            # Cache the actual video bytes, not the temporary file path.
            with open(video_path, "rb") as f:
                video_bytes = f.read()
            save_object_to_redis(redis_key, video_bytes)
            upload_to_gcs(gcs_bucket_name, redis_key, video_bytes)
        except Exception as e:
            print(f"Failed to generate video: {e}")
            return None
    return video_bytes
def test_model_meta_llama():
    redis_key = "meta_llama_test_response"
    response = load_object_from_redis(redis_key)
    if not response:
        try:
            messages = [
                {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
                {"role": "user", "content": "Who are you?"},
            ]
            with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
                # Chat-style pipelines return the full message list; the assistant's
                # reply is the last entry.
                output = meta_llama_pipeline(messages, max_new_tokens=256)
                response = output[0]["generated_text"][-1]["content"].strip()
                pbar.update(1)
            save_object_to_redis(redis_key, response)
            upload_to_gcs(gcs_bucket_name, redis_key, response.encode())
        except Exception as e:
            print(f"Failed to test Meta-Llama: {e}")
            return None
    return response
def train_model(model, dataset, epochs, batch_size, learning_rate):
    # TrainingArguments needs a real filesystem path, not an in-memory buffer.
    output_dir = "./training_output"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    try:
        with tqdm(total=epochs, desc="Training model") as pbar:
            trainer.train()
            pbar.update(epochs)
        save_object_to_redis("trained_model", model)
        save_object_to_redis("training_results", trainer.state.log_history)
        upload_to_gcs(gcs_bucket_name, "trained_model", pickle.dumps(model))
        upload_to_gcs(gcs_bucket_name, "training_results", json.dumps(trainer.state.log_history).encode())
    except Exception as e:
        print(f"Failed to train model: {e}")
def run_task(task_queue):
    while True:
        task = task_queue.get()
        if task is None:
            break
        func, args, kwargs = task
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(f"Failed to run task: {e}")
task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()
processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)
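# Load models, trying the Redis/GCS cache first and falling back to a Hub download.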
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
# Move the diffusers pipelines onto the GPU when one is available.
for pipe in (text_to_image_pipeline, img2img_pipeline, flux_pipeline):
    if pipe is not None:
        pipe.to(device)
text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b")
music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained("melody")
# transformers' pipeline() takes a task name first, so wrap it to match the (model_id, torch_dtype) call in get_model_or_download.
meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", lambda model_id, torch_dtype: transformers_pipeline("text-generation", model=model_id, torch_dtype=torch_dtype))
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
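# Gradio UI: one tab per capability, all backed by the cached generators above.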
gen_image_tab = gr.Interface(fn=generate_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate Image")
edit_image_tab = gr.Interface(fn=edit_image_with_prompt, inputs=[gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], outputs=gr.Image(type="pil"), title="Edit Image")
generate_song_tab = gr.Interface(fn=generate_song, inputs=[gr.Textbox(label="Prompt:"), gr.Slider(5, 60, 10, step=1, label="Duration (s):")], outputs=gr.Audio(type="numpy"), title="Generate Song")
generate_text_tab = gr.Interface(fn=generate_text, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Text:"), title="Generate Text")
generate_flux_image_tab = gr.Interface(fn=generate_flux_image, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Image(type="pil"), title="Generate FLUX Image")
generate_code_tab = gr.Interface(fn=generate_code, inputs=gr.Textbox(label="Prompt:"), outputs=gr.Textbox(label="Generated Code:"), title="Generate Code")
model_meta_llama_test_tab = gr.Interface(fn=test_model_meta_llama, inputs=None, outputs=gr.Textbox(label="Model Output:"), title="Test Meta-Llama")

app = gr.TabbedInterface(
    [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, generate_code_tab, model_meta_llama_test_tab],
    ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Generate Code", "Test Meta-Llama"],
)
app.launch(share=True)
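# launch() blocks until the app stops; afterwards, send one None sentinel per
# worker and wait for each process to exit.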
for _ in range(num_processes):
    task_queue.put(None)
for p in processes:
    p.join()