Ffftdtd5dtft commited on
Commit
838294b
verified
1 Parent(s): fd4d632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -6,17 +6,16 @@ from PIL import Image
6
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
7
  from diffusers.utils import export_to_video
8
  from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
9
- from audiocraft.models import MusicGen
10
  import gradio as gr
11
  from huggingface_hub import snapshot_download, HfApi, HfFolder
12
  import multiprocessing
13
  import io
14
  import time
15
 
16
- # Obtener las variables de entorno
17
  hf_token = os.getenv("HF_TOKEN")
18
  redis_host = os.getenv("REDIS_HOST")
19
- redis_port = int(os.getenv("REDIS_PORT", 6379)) # Valor predeterminado si no se proporciona
20
  redis_password = os.getenv("REDIS_PASSWORD")
21
 
22
  HfFolder.save_token(hf_token)
@@ -25,8 +24,7 @@ def connect_to_redis():
25
  while True:
26
  try:
27
  redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
28
- redis_client.ping() # Verifica si la conexi贸n est谩 activa
29
- print("Connected to Redis successfully.")
30
  return redis_client
31
  except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
32
  print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
@@ -55,19 +53,16 @@ def save_object_to_redis(key, obj):
55
  redis_client = reconnect_if_needed(redis_client)
56
  try:
57
  redis_client.set(key, pickle.dumps(obj))
58
- print(f"Object saved to Redis: {key}")
59
  except redis.exceptions.RedisError as e:
60
  print(f"Failed to save object to Redis: {e}")
61
 
62
  def get_model_or_download(model_id, redis_key, loader_func):
63
  model = load_object_from_redis(redis_key)
64
  if model:
65
- print(f"Model loaded from Redis: {redis_key}")
66
  return model
67
  try:
68
  model = loader_func(model_id, torch_dtype=torch.float16)
69
  save_object_to_redis(redis_key, model)
70
- print(f"Model downloaded and saved to Redis: {redis_key}")
71
  except Exception as e:
72
  print(f"Failed to load or save model: {e}")
73
  return None
@@ -221,25 +216,26 @@ for _ in range(num_processes):
221
 
222
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223
 
224
- # Cargar modelos
225
  text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
226
  img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
227
  flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
228
  text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b", device=0)
229
- music_gen = load_object_from_redis("music_gen") or MusicGen.from_pretrained('melody')
230
  meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
 
 
231
 
232
- # Definir interfaces de usuario
233
  gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Image")
234
  edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Image")
235
  generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
236
  generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
237
  generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")
 
238
  model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.inputs.Textbox(label="Test Input:"), gr.outputs.Textbox(label="Model Output:"), title="Test Meta-Llama")
239
 
240
  app = gr.TabbedInterface(
241
- [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
242
- ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"]
243
  )
244
 
245
  app.launch(share=True)
@@ -247,4 +243,4 @@ app.launch(share=True)
247
  for _ in range(num_processes):
248
  task_queue.put(None)
249
  for p in processes:
250
- p.join()
 
6
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
7
  from diffusers.utils import export_to_video
8
  from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
9
+ from audiocraft.models import musicgen
10
  import gradio as gr
11
  from huggingface_hub import snapshot_download, HfApi, HfFolder
12
  import multiprocessing
13
  import io
14
  import time
15
 
 
16
  hf_token = os.getenv("HF_TOKEN")
17
  redis_host = os.getenv("REDIS_HOST")
18
+ redis_port = int(os.getenv("REDIS_PORT", 6379))
19
  redis_password = os.getenv("REDIS_PASSWORD")
20
 
21
  HfFolder.save_token(hf_token)
 
24
  while True:
25
  try:
26
  redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
27
+ redis_client.ping()
 
28
  return redis_client
29
  except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
30
  print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
 
53
  redis_client = reconnect_if_needed(redis_client)
54
  try:
55
  redis_client.set(key, pickle.dumps(obj))
 
56
  except redis.exceptions.RedisError as e:
57
  print(f"Failed to save object to Redis: {e}")
58
 
59
  def get_model_or_download(model_id, redis_key, loader_func):
60
  model = load_object_from_redis(redis_key)
61
  if model:
 
62
  return model
63
  try:
64
  model = loader_func(model_id, torch_dtype=torch.float16)
65
  save_object_to_redis(redis_key, model)
 
66
  except Exception as e:
67
  print(f"Failed to load or save model: {e}")
68
  return None
 
216
 
217
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
218
 
 
219
  text_to_image_pipeline = get_model_or_download("stabilityai/stable-diffusion-2", "text_to_image_model", StableDiffusionPipeline.from_pretrained)
220
  img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained)
221
  flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
222
  text_gen_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b", device=0)
223
+ music_gen = load_object_from_redis("music_gen") or musicgen.MusicGen.get_pretrained('melody')
224
  meta_llama_pipeline = get_model_or_download("meta-llama/Meta-Llama-3.1-8B-Instruct", "meta_llama_model", transformers_pipeline)
225
+ starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
226
+ starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
227
 
 
228
  gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Image")
229
  edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Image")
230
  generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
231
  generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
232
  generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")
233
+ generate_code_tab = gr.Interface(generate_code, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Code:"), title="Generate Code")
234
  model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.inputs.Textbox(label="Test Input:"), gr.outputs.Textbox(label="Model Output:"), title="Test Meta-Llama")
235
 
236
  app = gr.TabbedInterface(
237
+ [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, generate_code_tab, model_meta_llama_test_tab],
238
+ ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Generate Code", "Test Meta-Llama"]
239
  )
240
 
241
  app.launch(share=True)
 
243
  for _ in range(num_processes):
244
  task_queue.put(None)
245
  for p in processes:
246
+ p.join()