Ffftdtd5dtft committed: Upload 2 files
- app.py +159 -0
- requirements (1).txt +9 -0
app.py
ADDED
@@ -0,0 +1,159 @@
import os
import pickle
import tempfile
import multiprocessing

import redis
import torch
import gradio as gr
from dotenv import load_dotenv
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    FluxPipeline,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.utils import export_to_video
from transformers import (
    pipeline as transformers_pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,  # needed by train_model below
)
from audiocraft.models import MusicGen
from huggingface_hub import HfFolder

# Load variables from the .env file
load_dotenv()

# Read the environment variables
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", 6379))  # redis.Redis expects an int port
redis_password = os.getenv("REDIS_PASSWORD")

# Register the Hugging Face token
if hf_token:
    HfFolder.save_token(hf_token)

# Redis helpers: models and pipelines are cached as pickled blobs
def connect_to_redis():
    return redis.Redis(host=redis_host, port=redis_port, password=redis_password)

def load_object_from_redis(key):
    with connect_to_redis() as redis_client:
        obj_data = redis_client.get(key)
    return pickle.loads(obj_data) if obj_data else None

def save_object_to_redis(key, obj):
    with connect_to_redis() as redis_client:
        redis_client.set(key, pickle.dumps(obj))

def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if model is None:
        model = loader_func(model_id, token=hf_token, torch_dtype=torch.float16)
        save_object_to_redis(redis_key, model)
    return model

def generate_image(prompt):
    return text_to_image_pipeline(prompt).images[0]

def edit_image_with_prompt(image, prompt, strength=0.75):
    # diffusers renamed the img2img argument from init_image to image
    return img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]

def generate_song(prompt, duration=10):
    # MusicGen takes the duration via set_generation_params and a list of descriptions
    music_gen.set_generation_params(duration=duration)
    wav = music_gen.generate([prompt])
    # Return (sample_rate, samples) so the gr.Audio output can render it
    return music_gen.sample_rate, wav[0, 0].cpu().numpy()

def generate_text(prompt):
    messages = [{"role": "user", "content": prompt}]
    return text_gen_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()

def generate_flux_image(prompt):
    return flux_pipeline(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

def generate_code(prompt):
    # Follow the model's own device placement instead of hardcoding "cuda"
    inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
    outputs = starcoder_model.generate(inputs, max_new_tokens=256)
    return starcoder_tokenizer.decode(outputs[0])

def generate_video(prompt):
    pipe = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return export_to_video(pipe(prompt, num_inference_steps=25).frames)

def test_model_meta_llama(prompt):
    # The Gradio tab passes a test input, so use it as the user message
    messages = [
        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
        {"role": "user", "content": prompt},
    ]
    return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"]

def train_model(model, dataset, epochs, batch_size, learning_rate):
    # TrainingArguments needs a real directory path, not an in-memory buffer
    output_dir = tempfile.mkdtemp()
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
    save_object_to_redis("trained_model", model)

def run_task(task_queue):
    # Worker loop: execute queued (func, args, kwargs) tuples until a None sentinel arrives
    while True:
        task = task_queue.get()
        if task is None:
            break
        func, args, kwargs = task
        func(*args, **kwargs)

task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()

processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)

device = "cuda" if torch.cuda.is_available() else "cpu"
text_to_image_pipeline = get_model_or_download(
    "CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained
).to(device)
img2img_pipeline = get_model_or_download(
    "runwayml/stable-diffusion-inpainting", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained
).to(device)
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
flux_pipeline.enable_model_cpu_offload()
# MusicGen.get_pretrained does not take an auth-token argument
music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained("melody")
save_object_to_redis("music_gen", music_gen)
text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
    token=hf_token,
)
save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", token=hf_token)
starcoder_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, token=hf_token
)
meta_llama_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=hf_token,
)

# Gradio 4.x components replace the removed gr.inputs / gr.outputs namespaces
gen_image_tab = gr.Interface(generate_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate Images")
edit_image_tab = gr.Interface(
    edit_image_with_prompt,
    [gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, value=0.75, step=0.05, label="Strength:")],
    gr.Image(type="pil"),
    title="Edit Images",
)
generate_song_tab = gr.Interface(
    generate_song,
    [gr.Textbox(label="Prompt:"), gr.Slider(5, 60, value=10, step=1, label="Duration (s):")],
    gr.Audio(label="Song:"),
    title="Generate Songs",
)
generate_text_tab = gr.Interface(generate_text, gr.Textbox(label="Prompt:"), gr.Textbox(label="Generated Text:"), title="Generate Text")
generate_flux_image_tab = gr.Interface(generate_flux_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate FLUX Images")
model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.Textbox(label="Test Input:"), gr.Textbox(label="Model Output:"), title="Test Meta-Llama")

app = gr.TabbedInterface(
    [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
    ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"],
)

app.launch(share=True)

# Shut the worker pool down once the Gradio server exits
for _ in range(num_processes):
    task_queue.put(None)
for p in processes:
    p.join()
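For reference, app.py reads its secrets from a .env file via python-dotenv; the upload itself does not include one. A minimal sketch with placeholder values (the variable names come from the script, the values here are hypothetical):

HF_TOKEN=hf_xxxxxxxxxxxxxxxx
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=changeme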
requirements (1).txt
ADDED
@@ -0,0 +1,9 @@
python-dotenv
redis
diffusers
transformers
accelerate
torch
gradio
audiocraft
huggingface_hub
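None of these dependencies are version-pinned, so a plain `pip install -r "requirements (1).txt"` resolves whatever is current at install time; audiocraft in particular pins specific torch releases in its own requirements and may conflict with the unpinned torch line above. Pinning versions known to coexist would make the Space more reproducible.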