import os
import redis
import pickle
import torch
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.utils import export_to_video
from transformers import (
    pipeline as transformers_pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)
from audiocraft.models import MusicGen
import gradio as gr
from huggingface_hub import snapshot_download, HfApi, HfFolder
import multiprocessing
import io
import time

# Read configuration from environment variables
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = os.getenv("REDIS_PORT")
redis_password = os.getenv("REDIS_PASSWORD")

HfFolder.save_token(hf_token)


def connect_to_redis():
    while True:
        try:
            # Environment variables are strings, so the port must be cast to int
            redis_client = redis.Redis(host=redis_host, port=int(redis_port), password=redis_password)
            redis_client.ping()  # Verify that the connection is alive
            print("Connected to Redis successfully.")
            return redis_client
        except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
            print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
            time.sleep(1)


def reconnect_if_needed(redis_client):
    try:
        redis_client.ping()
    except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError):
        print("Reconnecting to Redis...")
        return connect_to_redis()
    return redis_client


def load_object_from_redis(key):
    redis_client = reconnect_if_needed(connect_to_redis())
    try:
        obj_data = redis_client.get(key)
        return pickle.loads(obj_data) if obj_data else None
    except (pickle.PickleError, redis.exceptions.RedisError) as e:
        print(f"Failed to load object from Redis: {e}")
        return None


def save_object_to_redis(key, obj):
    redis_client = reconnect_if_needed(connect_to_redis())
    try:
        if not redis_client.exists(key):  # Only save if the key does not already exist
            redis_client.set(key, pickle.dumps(obj))
            print(f"Object saved to Redis: {key}")
    except redis.exceptions.RedisError as e:
        print(f"Failed to save object to Redis: {e}")


def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if model:
        print(f"Model loaded from Redis: {redis_key}")
        return model
    model = loader_func(model_id, torch_dtype=torch.float16)
    save_object_to_redis(redis_key, model)
    print(f"Model downloaded and saved to Redis: {redis_key}")
    return model


def generate_image(prompt):
    redis_key = f"generated_image_{prompt}"
    image = load_object_from_redis(redis_key)
    if not image:
        image = text_to_image_pipeline(prompt).images[0]
        save_object_to_redis(redis_key, image)
    return image


def edit_image_with_prompt(image, prompt, strength=0.75):
    redis_key = f"edited_image_{prompt}_{strength}"
    edited_image = load_object_from_redis(redis_key)
    if not edited_image:
        # Current diffusers releases take `image=`, not the old `init_image=`
        edited_image = img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]
        save_object_to_redis(redis_key, edited_image)
    return edited_image


def generate_song(prompt, duration=10):
    redis_key = f"generated_song_{prompt}_{duration}"
    song = load_object_from_redis(redis_key)
    if song is None:
        # MusicGen takes the duration via set_generation_params and a list of prompts
        music_gen.set_generation_params(duration=duration)
        wav = music_gen.generate([prompt])[0]
        # Return a (sample_rate, samples) tuple, which is what gr.Audio(type="numpy") expects
        song = (music_gen.sample_rate, wav.cpu().numpy().T)
        save_object_to_redis(redis_key, song)
    return song


def generate_text(prompt):
    redis_key = f"generated_text_{prompt}"
    text = load_object_from_redis(redis_key)
    if not text:
        # Chat-style pipelines return the full message list; the reply is the last message's content
        result = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)
        text = result[0]["generated_text"][-1]["content"].strip()
        save_object_to_redis(redis_key, text)
    return text
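# Illustrative addition (not part of the original flow): a minimal smoke test
# for the Redis cache round-trip used by the helpers above. The key name
# "cache_demo" is arbitrary; call this manually against a live Redis instance.
def _cache_roundtrip_demo():
    save_object_to_redis("cache_demo", {"hello": "world"})
    assert load_object_from_redis("cache_demo") == {"hello": "world"}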
def generate_flux_image(prompt):
    redis_key = f"generated_flux_image_{prompt}"
    flux_image = load_object_from_redis(redis_key)
    if not flux_image:
        flux_image = flux_pipeline(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        save_object_to_redis(redis_key, flux_image)
    return flux_image


def generate_code(prompt):
    redis_key = f"generated_code_{prompt}"
    code = load_object_from_redis(redis_key)
    if not code:
        # With device_map="auto" the model may be sharded, so follow the model's device
        inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
        outputs = starcoder_model.generate(inputs, max_new_tokens=256)
        code = starcoder_tokenizer.decode(outputs[0])
        save_object_to_redis(redis_key, code)
    return code


def generate_video(prompt):
    redis_key = f"generated_video_{prompt}"
    video = load_object_from_redis(redis_key)
    if not video:
        pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
        save_object_to_redis(redis_key, video)
    return video


def test_model_meta_llama(_test_input):
    # The Gradio tab passes a textbox value; this test uses a fixed prompt, so the input is ignored
    redis_key = "meta_llama_test_response"
    response = load_object_from_redis(redis_key)
    if not response:
        messages = [
            {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
            {"role": "user", "content": "Who are you?"},
        ]
        response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()
        save_object_to_redis(redis_key, response)
    return response


def train_model(model, dataset, epochs, batch_size, learning_rate):
    # TrainingArguments needs a filesystem path, not a BytesIO buffer
    output_dir = "./training_output"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
    save_object_to_redis("trained_model", model)
    save_object_to_redis("training_results", trainer.state.log_history)


def run_task(task_queue):
    while True:
        task = task_queue.get()
        if task is None:  # Sentinel value: shut this worker down
            break
        func, args, kwargs = task
        func(*args, **kwargs)


task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()
processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)

device = "cuda" if torch.cuda.is_available() else "cpu"

text_to_image_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained
).to(device)
# Img2img needs a standard checkpoint; the inpainting variant has an incompatible UNet
img2img_pipeline = get_model_or_download(
    "runwayml/stable-diffusion-v1-5", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained
).to(device)
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
flux_pipeline.enable_model_cpu_offload()

music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained("facebook/musicgen-melody")
save_object_to_redis("music_gen", music_gen)

text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",  # "gemini-2-2b-it" is not a Hub model id
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
    token=hf_token,
)
save_object_to_redis("text_gen_pipeline", text_gen_pipeline)

starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", token=hf_token)
starcoder_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, token=hf_token
)
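# Illustrative only: the worker pool above is started but never fed anywhere in
# this script. A task would be enqueued as a (func, args, kwargs) tuple, e.g.
# to pre-warm a cache entry (the prompt string here is a made-up example):
#
#   task_queue.put((generate_image, ("a watercolor fox",), {}))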
device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=hf_token) meta_llama_pipeline = transformers_pipeline( "text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct", model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", use_auth_token=hf_token ) gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Images") edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Images") generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs") generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text") generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images") model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.inputs.Textbox(label="Test Input:"), gr.outputs.Textbox(label="Model Output:"), title="Test Meta-Llama") app = gr.TabbedInterface( [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab], ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"] ) app.launch(share=True) for _ in range(num_processes): task_queue.put(None) for p in processes: p.join()