Update app.py
Ffftdtd5dtft committed
app.py
CHANGED
@@ -16,7 +16,7 @@ import time
 # Get the environment variables
 hf_token = os.getenv("HF_TOKEN")
 redis_host = os.getenv("REDIS_HOST")
-redis_port = os.getenv("REDIS_PORT")
+redis_port = int(os.getenv("REDIS_PORT", 6379))  # Default value if not provided
 redis_password = os.getenv("REDIS_PASSWORD")
 
 HfFolder.save_token(hf_token)
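Aside: the save_object_to_redis and load_object_from_redis helpers used throughout this file sit outside the hunks shown. A minimal sketch of how the client and helpers could be wired, assuming redis-py and pickle serialization; os.getenv returns strings, which is why the int() cast on the port above matters:

    import os
    import pickle

    import redis

    # Hypothetical wiring; app.py's actual client setup is not part of this diff.
    redis_client = redis.Redis(
        host=os.getenv("REDIS_HOST"),
        port=int(os.getenv("REDIS_PORT", 6379)),  # must be an int by connect time
        password=os.getenv("REDIS_PASSWORD"),
    )

    def save_object_to_redis(key, obj):
        # pickle is one plausible serialization for arbitrary Python objects
        redis_client.set(key, pickle.dumps(obj))

    def load_object_from_redis(key):
        data = redis_client.get(key)
        return pickle.loads(data) if data is not None else None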
@@ -65,88 +65,123 @@ def get_model_or_download(model_id, redis_key, loader_func):
     if model:
         print(f"Model loaded from Redis: {redis_key}")
         return model
-    model = loader_func(model_id, torch_dtype=torch.float16)
-    save_object_to_redis(redis_key, model)
-    print(f"Model downloaded and saved to Redis: {redis_key}")
+    try:
+        model = loader_func(model_id, torch_dtype=torch.float16)
+        save_object_to_redis(redis_key, model)
+        print(f"Model downloaded and saved to Redis: {redis_key}")
+    except Exception as e:
+        print(f"Failed to load or save model: {e}")
     return model
 
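The generate_* helpers below all repeat the shape introduced here: check Redis, compute on a miss, cache the result, and return None on failure. As an illustration only (redis_cached and key_fmt are not names from the diff), the pattern could be factored into a decorator:

    import functools

    def redis_cached(key_fmt):
        # key_fmt is a str.format template applied to the wrapped function's arguments
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                key = key_fmt.format(*args, **kwargs)
                cached = load_object_from_redis(key)
                if cached is not None:
                    return cached
                try:
                    result = fn(*args, **kwargs)
                    save_object_to_redis(key, result)
                    return result
                except Exception as e:
                    print(f"Failed to compute {key}: {e}")
                    return None
            return wrapper
        return decorator

    @redis_cached("generated_image_{0}")
    def generate_image_cached(prompt):
        return text_to_image_pipeline(prompt).images[0]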
 def generate_image(prompt):
     redis_key = f"generated_image_{prompt}"
     image = load_object_from_redis(redis_key)
     if not image:
-        image = text_to_image_pipeline(prompt).images[0]
-        save_object_to_redis(redis_key, image)
+        try:
+            image = text_to_image_pipeline(prompt).images[0]
+            save_object_to_redis(redis_key, image)
+        except Exception as e:
+            print(f"Failed to generate image: {e}")
+            return None
     return image
 
 def edit_image_with_prompt(image, prompt, strength=0.75):
     redis_key = f"edited_image_{prompt}_{strength}"
     edited_image = load_object_from_redis(redis_key)
     if not edited_image:
-        edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
-        save_object_to_redis(redis_key, edited_image)
+        try:
+            edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
+            save_object_to_redis(redis_key, edited_image)
+        except Exception as e:
+            print(f"Failed to edit image: {e}")
+            return None
     return edited_image
 
 def generate_song(prompt, duration=10):
     redis_key = f"generated_song_{prompt}_{duration}"
     song = load_object_from_redis(redis_key)
     if not song:
-        song = music_gen.generate(prompt, duration=duration)
-        save_object_to_redis(redis_key, song)
+        try:
+            song = music_gen.generate(prompt, duration=duration)
+            save_object_to_redis(redis_key, song)
+        except Exception as e:
+            print(f"Failed to generate song: {e}")
+            return None
     return song
 
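generate_song assumes a music_gen.generate(prompt, duration=duration) signature. For comparison, audiocraft's documented MusicGen interface sets the duration separately and generates from a list of descriptions; a sketch, assuming the audiocraft package:

    from audiocraft.models import MusicGen

    model = MusicGen.get_pretrained('facebook/musicgen-melody')
    model.set_generation_params(duration=10)         # seconds
    wavs = model.generate(['upbeat acoustic folk'])  # one waveform tensor per description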
 def generate_text(prompt):
     redis_key = f"generated_text_{prompt}"
     text = load_object_from_redis(redis_key)
     if not text:
-        text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
-        save_object_to_redis(redis_key, text)
+        try:
+            text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
+            save_object_to_redis(redis_key, text)
+        except Exception as e:
+            print(f"Failed to generate text: {e}")
+            return None
     return text
 
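One caveat on the chat-style call above: in recent transformers releases, passing a message list to a text-generation pipeline returns generated_text as a list of message dicts rather than a string, so calling .strip() on it would raise. A sketch of extracting the assistant reply under that behavior (version-dependent; verify against the installed transformers):

    out = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)
    last_message = out[0]["generated_text"][-1]  # the assistant's turn
    text = last_message["content"].strip()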
 def generate_flux_image(prompt):
     redis_key = f"generated_flux_image_{prompt}"
     flux_image = load_object_from_redis(redis_key)
     if not flux_image:
-        flux_image = flux_pipeline(
-            prompt,
-            guidance_scale=0.0,
-            num_inference_steps=4,
-            max_sequence_length=256,
-            generator=torch.Generator("cpu").manual_seed(0)
-        ).images[0]
-        save_object_to_redis(redis_key, flux_image)
+        try:
+            flux_image = flux_pipeline(
+                prompt,
+                guidance_scale=0.0,
+                num_inference_steps=4,
+                max_sequence_length=256,
+                generator=torch.Generator("cpu").manual_seed(0)
+            ).images[0]
+            save_object_to_redis(redis_key, flux_image)
+        except Exception as e:
+            print(f"Failed to generate flux image: {e}")
+            return None
     return flux_image
 
 def generate_code(prompt):
     redis_key = f"generated_code_{prompt}"
     code = load_object_from_redis(redis_key)
     if not code:
-        inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
-        outputs = starcoder_model.generate(inputs)
-        code = starcoder_tokenizer.decode(outputs[0])
-        save_object_to_redis(redis_key, code)
+        try:
+            inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
+            outputs = starcoder_model.generate(inputs)
+            code = starcoder_tokenizer.decode(outputs[0])
+            save_object_to_redis(redis_key, code)
+        except Exception as e:
+            print(f"Failed to generate code: {e}")
+            return None
     return code
 
 def generate_video(prompt):
     redis_key = f"generated_video_{prompt}"
     video = load_object_from_redis(redis_key)
     if not video:
-        pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
-        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-        pipe.enable_model_cpu_offload()
-        video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
-        save_object_to_redis(redis_key, video)
+        try:
+            pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
+            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+            pipe.enable_model_cpu_offload()
+            video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
+            save_object_to_redis(redis_key, video)
+        except Exception as e:
+            print(f"Failed to generate video: {e}")
+            return None
     return video
 
 def test_model_meta_llama():
     redis_key = "meta_llama_test_response"
     response = load_object_from_redis(redis_key)
     if not response:
-        messages = [
-            {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
-            {"role": "user", "content": "Who are you?"}
-        ]
-        response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
-        save_object_to_redis(redis_key, response)
+        try:
+            messages = [
+                {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+                {"role": "user", "content": "Who are you?"}
+            ]
+            response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
+            save_object_to_redis(redis_key, response)
+        except Exception as e:
+            print(f"Failed to test Meta-Llama: {e}")
+            return None
     return response
 
 def train_model(model, dataset, epochs, batch_size, learning_rate):
@@ -158,9 +193,12 @@ def train_model(model, dataset, epochs, batch_size, learning_rate):
         learning_rate=learning_rate,
     )
     trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
-    trainer.train()
-    save_object_to_redis("trained_model", model)
-    save_object_to_redis("training_results", output_dir.getvalue())
+    try:
+        trainer.train()
+        save_object_to_redis("trained_model", model)
+        save_object_to_redis("training_results", output_dir.getvalue())
+    except Exception as e:
+        print(f"Failed to train model: {e}")
 
 def run_task(task_queue):
     while True:
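The TrainingArguments construction whose tail shows in this hunk's context is elided; only learning_rate is visible. A hypothetical reconstruction for orientation, where every field except learning_rate is an assumption:

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="./results",                  # assumption
        num_train_epochs=epochs,                 # assumption
        per_device_train_batch_size=batch_size,  # assumption
        learning_rate=learning_rate,             # visible in the hunk context
    )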
@@ -168,7 +206,10 @@ def run_task(task_queue):
         if task is None:
             break
         func, args, kwargs = task
-        func(*args, **kwargs)
+        try:
+            func(*args, **kwargs)
+        except Exception as e:
+            print(f"Failed to run task: {e}")
 
 task_queue = multiprocessing.Queue()
 num_processes = multiprocessing.cpu_count()
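run_task workers unpack (func, args, kwargs) tuples and exit when they dequeue None. The producer side is outside the hunks shown; a minimal sketch of enqueueing work and shutting the pool down cleanly:

    # Hypothetical producer, matching the tuple shape run_task unpacks:
    task_queue.put((generate_image, ("a red bicycle at dusk",), {}))

    # One None sentinel per worker triggers the "if task is None: break" path:
    for _ in range(num_processes):
        task_queue.put(None)
    for p in processes:
        p.join()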
@@ -179,33 +220,16 @@ for _ in range(num_processes):
     p.start()
     processes.append(p)
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
-img2img_pipeline = get_model_or_download("
-flux_pipeline = get_model_or_download("
-
-music_gen = load_object_from_redis("music_gen") or MusicGen.
-
-
-
-
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device=device,
-    use_auth_token=hf_token,
-)
-save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
-starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", use_auth_token=hf_token)
-starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=hf_token)
-meta_llama_pipeline = transformers_pipeline(
-    "text-generation",
-    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device_map="auto",
-    use_auth_token=hf_token
-)
-
-gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Images")
-edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Images")
+img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
+flux_pipeline = get_model_or_download("CompVis/stable-diffusion-flux", "flux_model", FluxPipeline.from_pretrained).to(device)
+text_gen_pipeline = transformers_pipeline("text-generation", model="bigcode/starcoder", tokenizer="bigcode/starcoder", device=0)
+music_gen = load_object_from_redis("music_gen") or MusicGen.from_pretrained('melody')
+meta_llama_pipeline = get_model_or_download("meta/meta-llama-7b", "meta_llama_model", transformers_pipeline)
+
+gen_image_tab = gr.Interface(generate_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate Image")
+edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.inputs.Image(type="pil", label="Image:"), gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:")], gr.outputs.Image(type="pil"), title="Edit Image")
 generate_song_tab = gr.Interface(generate_song, [gr.inputs.Textbox(label="Prompt:"), gr.inputs.Slider(5, 60, 10, step=1, label="Duration (s):")], gr.outputs.Audio(type="numpy"), title="Generate Songs")
 generate_text_tab = gr.Interface(generate_text, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Textbox(label="Generated Text:"), title="Generate Text")
 generate_flux_image_tab = gr.Interface(generate_flux_image, gr.inputs.Textbox(label="Prompt:"), gr.outputs.Image(type="pil"), title="Generate FLUX Images")
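The *_tab interfaces built above are presumably combined and launched further down in app.py, outside this diff. A minimal sketch using Gradio's tabbed container, with tab labels as assumptions. Note also that the gr.inputs and gr.outputs namespaces used here date from the pre-4 Gradio API and are removed in current Gradio releases, which take components such as gr.Textbox and gr.Image directly:

    demo = gr.TabbedInterface(
        [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab],
        ["Image", "Edit", "Song", "Text", "FLUX"],
    )
    demo.launch()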