Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,21 +2,24 @@ import torch
|
|
2 |
from diffusers import DiffusionPipeline, DDPMScheduler
|
3 |
from accelerate import Accelerator
|
4 |
from datasets import load_dataset
|
5 |
-
from huggingface_hub import HfFolder, Repository, whoami
|
6 |
from tqdm.auto import tqdm
|
7 |
from transformers import TrainingArguments
|
|
|
8 |
|
9 |
# Konfigurasi
|
10 |
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
11 |
-
dataset_name = "DucHaiten/anime-SDXL"
|
12 |
learning_rate = 1e-5
|
13 |
-
num_train_epochs =
|
14 |
-
train_batch_size =
|
15 |
-
gradient_accumulation_steps =
|
16 |
-
output_dir = "flux-anime"
|
|
|
17 |
|
18 |
# Muat model dan scheduler
|
19 |
-
pipeline = DiffusionPipeline.from_pretrained(
|
|
|
|
|
20 |
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
|
21 |
pipeline.enable_xformers_memory_efficient_attention()
|
22 |
|
@@ -25,9 +28,14 @@ dataset = load_dataset(dataset_name)["train"]
|
|
25 |
|
26 |
# Fungsi untuk memproses data
|
27 |
def preprocess_function(examples):
|
28 |
-
images = [
|
|
|
|
|
|
|
29 |
texts = [text for text in examples["text"]]
|
30 |
-
examples["pixel_values"] = pipeline.feature_extractor(
|
|
|
|
|
31 |
examples["prompt"] = texts
|
32 |
return examples
|
33 |
|
@@ -62,34 +70,82 @@ training_args = TrainingArguments(
|
|
62 |
num_train_epochs=num_train_epochs,
|
63 |
learning_rate=learning_rate,
|
64 |
fp16=True,
|
65 |
-
logging_dir="./logs",
|
66 |
-
report_to="tensorboard"
|
|
|
67 |
)
|
68 |
|
69 |
# Training loop
|
70 |
-
progress_bar = tqdm(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
for epoch in range(num_train_epochs):
|
72 |
pipeline.unet.train()
|
73 |
-
for step, batch in enumerate(
|
|
|
|
|
74 |
with accelerator.accumulate(pipeline.unet):
|
75 |
-
latents = pipeline.vae.encode(
|
|
|
|
|
76 |
latents = latents * pipeline.vae.config.scaling_factor
|
77 |
noise = torch.randn_like(latents)
|
78 |
bsz = latents.shape[0]
|
79 |
-
timesteps = torch.randint(
|
|
|
|
|
|
|
|
|
|
|
80 |
timesteps = timesteps.long()
|
81 |
-
noisy_latents = pipeline.scheduler.add_noise(
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
87 |
accelerator.backward(loss)
|
88 |
optimizer.step()
|
89 |
optimizer.zero_grad()
|
90 |
progress_bar.update(1)
|
91 |
|
|
|
|
|
|
|
92 |
# Simpan model
|
93 |
pipeline.save_pretrained(output_dir)
|
94 |
-
repo.push_to_hub(commit_message=f"Training epoch {epoch}", blocking=False)
|
95 |
|
|
|
2 |
from diffusers import DiffusionPipeline, DDPMScheduler
|
3 |
from accelerate import Accelerator
|
4 |
from datasets import load_dataset
|
|
|
5 |
from tqdm.auto import tqdm
|
6 |
from transformers import TrainingArguments
|
7 |
+
import gradio as gr
|
8 |
|
9 |
# Konfigurasi
|
10 |
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
11 |
+
dataset_name = "DucHaiten/anime-SDXL" # Gunakan dataset sesuai keinginan Anda
|
12 |
learning_rate = 1e-5
|
13 |
+
num_train_epochs = 2 # Sesuaikan dengan kebutuhan
|
14 |
+
train_batch_size = 1 # Gunakan batch size kecil untuk Spaces gratis
|
15 |
+
gradient_accumulation_steps = 4 # Sesuaikan dengan kebutuhan
|
16 |
+
output_dir = "flux-anime"
|
17 |
+
image_resize = 128 # Sesuaikan dengan kebutuhan
|
18 |
|
19 |
# Muat model dan scheduler
|
20 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
21 |
+
pretrained_model_name_or_path, torch_dtype=torch.float16
|
22 |
+
)
|
23 |
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
|
24 |
pipeline.enable_xformers_memory_efficient_attention()
|
25 |
|
|
|
28 |
|
29 |
# Fungsi untuk memproses data
|
30 |
def preprocess_function(examples):
|
31 |
+
images = [
|
32 |
+
image.convert("RGB").resize((image_resize, image_resize))
|
33 |
+
for image in examples["image"]
|
34 |
+
]
|
35 |
texts = [text for text in examples["text"]]
|
36 |
+
examples["pixel_values"] = pipeline.feature_extractor(
|
37 |
+
images=images, return_tensors="pt"
|
38 |
+
).pixel_values
|
39 |
examples["prompt"] = texts
|
40 |
return examples
|
41 |
|
|
|
70 |
num_train_epochs=num_train_epochs,
|
71 |
learning_rate=learning_rate,
|
72 |
fp16=True,
|
73 |
+
logging_dir="./logs",
|
74 |
+
report_to="tensorboard",
|
75 |
+
push_to_hub=True, # Push model ke Hugging Face Hub
|
76 |
)
|
77 |
|
78 |
# Training loop
|
79 |
+
progress_bar = tqdm(
|
80 |
+
range(num_train_epochs * len(processed_dataset) // train_batch_size)
|
81 |
+
)
|
82 |
+
|
83 |
+
# --- Komponen Gradio ---
|
84 |
+
with gr.Blocks() as interface:
|
85 |
+
gr.Markdown(
|
86 |
+
"## Fine-tuning FLUX untuk Anime"
|
87 |
+
) # Ganti judul sesuai dataset Anda
|
88 |
+
loss_textbox = gr.Textbox(label="Loss")
|
89 |
+
epoch_textbox = gr.Textbox(label="Epoch")
|
90 |
+
progress_bar_gradio = gr.ProgressBar(label="Progress")
|
91 |
+
output_image = gr.Image(label="Generated Image")
|
92 |
+
|
93 |
+
def train_step(step, epoch, loss):
|
94 |
+
loss_textbox.update(value=loss)
|
95 |
+
epoch_textbox.update(value=epoch)
|
96 |
+
progress_bar_gradio.update(value=step / len(progress_bar))
|
97 |
+
if step % 100 == 0:
|
98 |
+
with torch.no_grad():
|
99 |
+
image = pipeline(
|
100 |
+
"anime style image of a girl with blue hair"
|
101 |
+
).images[
|
102 |
+
0
|
103 |
+
] # Ganti prompt sesuai dataset Anda
|
104 |
+
output_image.update(value=image)
|
105 |
+
return loss, epoch, step / len(progress_bar)
|
106 |
+
|
107 |
+
interface.launch(server_name="0.0.0.0")
|
108 |
+
|
109 |
+
# ------------------------
|
110 |
+
|
111 |
for epoch in range(num_train_epochs):
|
112 |
pipeline.unet.train()
|
113 |
+
for step, batch in enumerate(
|
114 |
+
processed_dataset.iter(batch_size=train_batch_size)
|
115 |
+
):
|
116 |
with accelerator.accumulate(pipeline.unet):
|
117 |
+
latents = pipeline.vae.encode(
|
118 |
+
batch["pixel_values"].to(dtype=torch.float16)
|
119 |
+
).latent_dist.sample()
|
120 |
latents = latents * pipeline.vae.config.scaling_factor
|
121 |
noise = torch.randn_like(latents)
|
122 |
bsz = latents.shape[0]
|
123 |
+
timesteps = torch.randint(
|
124 |
+
0,
|
125 |
+
pipeline.scheduler.config.num_train_timesteps,
|
126 |
+
(bsz,),
|
127 |
+
device=latents.device,
|
128 |
+
)
|
129 |
timesteps = timesteps.long()
|
130 |
+
noisy_latents = pipeline.scheduler.add_noise(
|
131 |
+
latents, noise, timesteps
|
132 |
+
)
|
133 |
+
|
134 |
+
model_pred = pipeline.unet(
|
135 |
+
noisy_latents, timesteps, batch["prompt"]
|
136 |
+
).sample
|
137 |
+
|
138 |
+
loss = torch.nn.functional.mse_loss(
|
139 |
+
model_pred.float(), noise.float(), reduction="mean"
|
140 |
+
)
|
141 |
accelerator.backward(loss)
|
142 |
optimizer.step()
|
143 |
optimizer.zero_grad()
|
144 |
progress_bar.update(1)
|
145 |
|
146 |
+
# Update komponen Gradio
|
147 |
+
train_step(step, epoch, loss.item())
|
148 |
+
|
149 |
# Simpan model
|
150 |
pipeline.save_pretrained(output_dir)
|
|
|
151 |
|