Ffftdtd5dtft committed
Commit f13c41f · verified · 1 Parent(s): 0515551

Upload 2 files

Files changed (2):
  1. app.py +159 -0
  2. requirements (1).txt +9 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ import redis
+ import pickle
+ import torch
+ import tempfile
+ from PIL import Image
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
+ from diffusers.utils import export_to_video
+ from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
+ from audiocraft.models import MusicGen
+ import gradio as gr
+ from huggingface_hub import snapshot_download, HfApi, HfFolder
+ import multiprocessing
+ from dotenv import load_dotenv
+ import os
+
+ # Load the variables from the .env file
+ load_dotenv()
+
+ # Read the environment variables
+ hf_token = os.getenv("HF_TOKEN")
+ redis_host = os.getenv("REDIS_HOST")
+ redis_port = int(os.getenv("REDIS_PORT", "6379"))  # redis-py expects an int port
+ redis_password = os.getenv("REDIS_PASSWORD")
+
+ # Save the Hugging Face token so later downloads are authenticated
+ HfFolder.save_token(hf_token)
+
+ # Redis helpers: models and pipelines are cached as pickled blobs
+ def connect_to_redis():
+     return redis.Redis(host=redis_host, port=redis_port, password=redis_password)
+
+ def load_object_from_redis(key):
+     with connect_to_redis() as redis_client:
+         obj_data = redis_client.get(key)
+         return pickle.loads(obj_data) if obj_data else None
+
+ def save_object_to_redis(key, obj):
+     with connect_to_redis() as redis_client:
+         redis_client.set(key, pickle.dumps(obj))
+
+ def get_model_or_download(model_id, redis_key, loader_func):
+     # Check the Redis cache first and download from the Hub only on a miss
+     model = load_object_from_redis(redis_key)
+     if not model:
+         model = loader_func(model_id, use_auth_token=hf_token, torch_dtype=torch.float16)
+         save_object_to_redis(redis_key, model)
+     return model
+ def generate_image(prompt):
+     return text_to_image_pipeline(prompt).images[0]
+
+ def edit_image_with_prompt(image, prompt, strength=0.75):
+     # Current diffusers img2img pipelines take the source picture as `image`; the old `init_image` name is gone
+     return img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]
+
+ def generate_song(prompt, duration=10):
+     # MusicGen takes the clip length via set_generation_params and a list of text descriptions;
+     # return (sample_rate, waveform) so Gradio's numpy audio output can play it
+     music_gen.set_generation_params(duration=duration)
+     wav = music_gen.generate([prompt])[0, 0]
+     return music_gen.sample_rate, wav.cpu().numpy()
+
+ def generate_text(prompt):
+     return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()
+ def generate_flux_image(prompt):
+     return flux_pipeline(
+         prompt,
+         guidance_scale=0.0,
+         num_inference_steps=4,
+         max_sequence_length=256,
+         generator=torch.Generator("cpu").manual_seed(0)
+     ).images[0]
+
+ def generate_code(prompt):
+     # Move the inputs to wherever device_map="auto" placed the first model shard
+     inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
+     outputs = starcoder_model.generate(inputs, max_new_tokens=256)
+     return starcoder_tokenizer.decode(outputs[0])
+
+ def generate_video(prompt):
+     pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
+     pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+     pipe.enable_model_cpu_offload()
+     # export_to_video writes the generated frames to an .mp4 file and returns its path
+     return export_to_video(pipe(prompt, num_inference_steps=25).frames)
+ def test_model_meta_llama(prompt):
+     # The Gradio tab feeds a test prompt in; fall back to a fixed question when it is empty
+     messages = [
+         {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+         {"role": "user", "content": prompt or "Who are you?"}
+     ]
+     return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"]
+
+ def train_model(model, dataset, epochs, batch_size, learning_rate):
+     # TrainingArguments needs a real directory for checkpoints, not an in-memory buffer
+     output_dir = tempfile.mkdtemp()
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         num_train_epochs=epochs,
+         per_device_train_batch_size=batch_size,
+         learning_rate=learning_rate,
+     )
+     trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
+     trainer.train()
+     save_object_to_redis("trained_model", model)
+     save_object_to_redis("training_results", output_dir)
+ # Worker pool: each process pulls (func, args, kwargs) tuples from the queue
+ # and exits when it receives the None sentinel
+ def run_task(task_queue):
+     while True:
+         task = task_queue.get()
+         if task is None:
+             break
+         func, args, kwargs = task
+         func(*args, **kwargs)
+
+ task_queue = multiprocessing.Queue()
+ num_processes = multiprocessing.cpu_count()
+
+ processes = []
+ for _ in range(num_processes):
+     p = multiprocessing.Process(target=run_task, args=(task_queue,))
+     p.start()
+     processes.append(p)
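+ # Illustrative only: any top-level function could be queued from the UI handlers,
+ # e.g. task_queue.put((generate_image, ("a watercolor fox",), {})).
+ # Nothing below enqueues work; only the None sentinels at the end stop the workers.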
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
+ # An img2img checkpoint is needed here; the inpainting checkpoint expects extra mask
+ # channels that StableDiffusionImg2ImgPipeline does not provide
+ img2img_pipeline = get_model_or_download("runwayml/stable-diffusion-v1-5", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
+ flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
+ flux_pipeline.enable_model_cpu_offload()
+ # audiocraft's MusicGen.get_pretrained takes a model name (and optional device), not an auth token
+ music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained("facebook/musicgen-melody")
+ save_object_to_redis("music_gen", music_gen)
+ text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
+     "text-generation",
+     model="google/gemma-2-2b-it",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+     device=device,
+     use_auth_token=hf_token,
+ )
+ save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
+ starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", use_auth_token=hf_token)
+ starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=hf_token)
+ meta_llama_pipeline = transformers_pipeline(
+     "text-generation",
+     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+     device_map="auto",
+     use_auth_token=hf_token
+ )
+
+ # Gradio 4 removed the gr.inputs/gr.outputs namespaces; components are passed directly
+ gen_image_tab = gr.Interface(generate_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate Images")
+ edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, value=0.75, step=0.05, label="Strength:")], gr.Image(type="pil"), title="Edit Images")
+ generate_song_tab = gr.Interface(generate_song, [gr.Textbox(label="Prompt:"), gr.Slider(5, 60, value=10, step=1, label="Duration (s):")], gr.Audio(type="numpy"), title="Generate Songs")
+ generate_text_tab = gr.Interface(generate_text, gr.Textbox(label="Prompt:"), gr.Textbox(label="Generated Text:"), title="Generate Text")
+ generate_flux_image_tab = gr.Interface(generate_flux_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate FLUX Images")
+ model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.Textbox(label="Test Input:"), gr.Textbox(label="Model Output:"), title="Test Meta-Llama")
+
+ app = gr.TabbedInterface(
+     [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
+     ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"]
+ )
+
+ # launch() blocks until the server stops; the worker pool is then shut down cleanly
+ app.launch(share=True)
+
+ for _ in range(num_processes):
+     task_queue.put(None)
+ for p in processes:
+     p.join()
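
The heart of app.py is the pickle-to-Redis caching round trip that get_model_or_download builds on. A minimal, self-contained sketch of that pattern, assuming a local Redis on localhost:6379 with no password (the connection details and the demo key/value are placeholders, not taken from the commit):

import pickle
import redis

# Hypothetical connection; app.py reads host/port/password from .env instead
client = redis.Redis(host="localhost", port=6379)

def cache_object(key, obj):
    # pickle.dumps serializes any picklable Python object into bytes Redis can store
    client.set(key, pickle.dumps(obj))

def fetch_object(key):
    data = client.get(key)  # returns None on a cache miss
    return pickle.loads(data) if data else None

cache_object("greeting", {"text": "hola"})
print(fetch_object("greeting"))  # -> {'text': 'hola'}

Note that a single Redis string value is capped at 512 MB, so caching multi-gigabyte pipelines this way will fail long before memory runs out; the pattern suits tokenizers and small objects far better than full diffusion checkpoints.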
requirements (1).txt ADDED
@@ -0,0 +1,9 @@
+ python-dotenv
+ redis
+ diffusers
+ transformers
+ accelerate
+ torch
+ gradio
+ audiocraft
+ huggingface_hub
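
app.py expects its secrets in a .env file loaded by python-dotenv. A sketch of that file with placeholder values (the variable names come from the os.getenv calls in app.py; every value below is illustrative):

HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=changeme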