# Hugging Face Spaces page status when this script was captured:
# "Runtime error" (shown twice on the page).
import gradio as gr
import torch
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, DiffusionPipeline
from tqdm.auto import tqdm
from transformers import TrainingArguments
# Configuration
# NOTE(review): per its model card, FLUX.1-dev is a large, gated,
# transformer-based model (its pipeline has no `unet`); the Stable
# Diffusion-style training code below assumes a UNet pipeline. Confirm the
# intended base model.
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
dataset_name = "DucHaiten/anime-SDXL"  # use whichever dataset you want
learning_rate = 1e-5
num_train_epochs = 2  # adjust as needed
train_batch_size = 1  # keep the batch small for free Spaces hardware
gradient_accumulation_steps = 4  # adjust as needed
output_dir = "flux-anime"
image_resize = 128  # square training resolution in pixels; adjust as needed
# Load the pretrained pipeline in half precision and swap in a DDPM scheduler;
# the training loop relies on its `add_noise` and `num_train_timesteps`.
# NOTE(review): a FLUX checkpoint loads a flow-matching scheduler; replacing it
# with DDPM only makes sense for Stable Diffusion-style models. Confirm.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=torch.float16
)
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
try:
    pipeline.enable_xformers_memory_efficient_attention()
except (ModuleNotFoundError, ValueError) as err:
    # diffusers raises ModuleNotFoundError when xformers is not installed and
    # ValueError when no CUDA device is available; keep default attention then.
    print(f"xformers attention unavailable, using default attention: {err}")
# Load the training split of the dataset.
dataset = load_dataset(dataset_name)["train"]
# Function that prepares one batch of raw dataset rows for training.
def preprocess_function(examples, size=None):
    """Add VAE-ready pixel arrays and caption prompts to a batch of rows.

    Used with ``Dataset.map(batched=True)``, so every value in *examples* is a
    list with one entry per row. Pixels are computed directly with numpy
    instead of through a pipeline component, so the function does not capture
    the (very large) pipeline object when ``datasets`` hashes or pickles it.

    Args:
        examples: Batch dict. Assumed to hold an ``"image"`` column of PIL-like
            images (with ``convert``/``resize``) and a ``"text"`` column of
            captions -- column names come from this script; TODO confirm
            against the dataset card.
        size: Square edge length in pixels; defaults to the module-level
            ``image_resize``.

    Returns:
        *examples* with two keys added: ``"pixel_values"``, a float32 array of
        shape ``(batch, 3, size, size)`` scaled to ``[-1, 1]``, and
        ``"prompt"``, the list of captions.
    """
    import numpy as np

    if size is None:
        size = image_resize
    arrays = [
        np.asarray(image.convert("RGB").resize((size, size)), dtype=np.float32)
        for image in examples["image"]
    ]
    if arrays:
        stacked = np.stack(arrays)  # (batch, height, width, 3)
        # Channels-first, then map [0, 255] -> [-1, 1] (the range diffusers'
        # VaeImageProcessor produces for VAE input).
        pixel_values = np.ascontiguousarray(
            stacked.transpose(0, 3, 1, 2) / 127.5 - 1.0, dtype=np.float32
        )
    else:
        # np.stack rejects an empty list; keep a correctly shaped empty array.
        pixel_values = np.zeros((0, 3, size, size), dtype=np.float32)
    examples["pixel_values"] = pixel_values
    examples["prompt"] = list(examples["text"])
    return examples
# Preprocess the whole dataset once; only the new columns are kept.
processed_dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=dataset.column_names,
)
# Without a format, Dataset.iter() yields plain nested lists. Return
# pixel_values as torch tensors; other columns ("prompt") stay Python objects.
processed_dataset = processed_dataset.with_format(
    "torch", columns=["pixel_values"], output_all_columns=True
)
# Accelerator handles device placement, fp16 autocast/loss scaling and
# gradient accumulation for the objects passed to `prepare`.
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="fp16",
)

# The VAE is only used to encode images: freeze it so backward() computes no
# gradients for it, and place it on the device directly (wrapping a module
# without trainable parameters for distributed training is rejected).
pipeline.vae.requires_grad_(False)
pipeline.vae.to(accelerator.device)

# fp16 mixed precision expects fp32 master weights (the loss scaler refuses to
# unscale fp16 gradients); the pipeline was loaded with torch_dtype=float16.
pipeline.unet.to(dtype=torch.float32)

# Optimizer over the trainable parameters only.
optimizer = torch.optim.AdamW(
    pipeline.unet.parameters(),
    lr=learning_rate,
)

# Prepare model and optimizer together so optimizer.step() goes through the
# loss scaler and is skipped while gradients are still being accumulated.
# The datasets.Dataset is iterated directly in the loop, so it is not passed.
pipeline.unet, optimizer = accelerator.prepare(pipeline.unet, optimizer)
# Hugging Face TrainingArguments describing this run.
# NOTE(review): nothing in this file reads `training_args`. The loop below is
# hand-written and does not use transformers.Trainer, so push_to_hub,
# logging_dir and report_to have no effect. Wire them into the loop or remove.
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    fp16=True,
    logging_dir="./logs",
    report_to="tensorboard",
    push_to_hub=True,  # push the model to the Hugging Face Hub
)
# Progress bar with one tick per batch over the whole run.
# Dataset.iter() also yields the final partial batch, so an epoch has
# ceil(len / batch_size) batches; floor division under-counted that case.
steps_per_epoch = -(-len(processed_dataset) // train_batch_size)  # ceiling division
progress_bar = tqdm(total=num_train_epochs * steps_per_epoch)
# --- Gradio dashboard ---
# train_step() writes the latest metrics into this dict, and each component
# below re-reads it every few seconds through `every=`. Calling a component's
# `.update()` outside an event handler never changes a running UI, and there
# is no gr.ProgressBar component.
_dashboard_state = {"loss": None, "epoch": None, "progress": 0.0, "image": None}
_DASHBOARD_REFRESH_SECONDS = 2

with gr.Blocks() as interface:
    gr.Markdown("## Fine-tuning FLUX untuk Anime")  # change the title to match your dataset
    loss_textbox = gr.Textbox(
        label="Loss",
        value=lambda: _dashboard_state["loss"],
        every=_DASHBOARD_REFRESH_SECONDS,
    )
    epoch_textbox = gr.Textbox(
        label="Epoch",
        value=lambda: _dashboard_state["epoch"],
        every=_DASHBOARD_REFRESH_SECONDS,
    )
    # A read-only 0..1 slider stands in for a progress bar.
    progress_bar_gradio = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        label="Progress",
        interactive=False,
        value=lambda: _dashboard_state["progress"],
        every=_DASHBOARD_REFRESH_SECONDS,
    )
    output_image = gr.Image(
        label="Generated Image",
        value=lambda: _dashboard_state["image"],
        every=_DASHBOARD_REFRESH_SECONDS,
    )


def train_step(step, epoch, loss):
    """Publish one training step to the dashboard.

    Every 100th step of an epoch (including step 0), it also generates a
    sample image with the current model.

    Args:
        step: Index of the batch within the current epoch (0-based).
        epoch: Current epoch index.
        loss: Scalar loss of this step.

    Returns:
        ``(loss, epoch, progress)``, where ``progress`` is the fraction of all
        planned batches completed, read from the module-level tqdm bar.
    """
    total = len(progress_bar)
    # progress_bar.n counts batches across all epochs, whereas `step` restarts
    # at 0 every epoch, so this fraction keeps increasing instead of resetting.
    progress = progress_bar.n / total if total else 0.0
    _dashboard_state.update(loss=loss, epoch=epoch, progress=progress)
    if step % 100 == 0:
        # Sample in eval mode, then restore the training mode the loop set.
        pipeline.unet.eval()
        try:
            with torch.no_grad():
                # change the prompt to match your dataset
                image = pipeline("anime style image of a girl with blue hair").images[0]
        finally:
            pipeline.unet.train()
        _dashboard_state["image"] = image
    return loss, epoch, progress


# Some Gradio 3.x releases require the queue for `every=` polling. Launch
# without blocking the main thread so the training code that follows runs.
interface.queue()
interface.launch(server_name="0.0.0.0", prevent_thread_lock=True)
# ------------------------
# --- Training loop ---
# NOTE(review): this loop follows the Stable Diffusion recipe: tokenizer and
# text_encoder conditioning a `unet`, with a DDPM epsilon/v-prediction loss.
# A FLUX.1 checkpoint loads a pipeline with a `transformer` instead of a
# `unet`, so it needs an SD-style base model or a FLUX-specific loop.

# Put every pipeline component on the training device; both the loop and the
# dashboard's sample generation need them there.
pipeline.to(accelerator.device)
text_encoder = pipeline.text_encoder
tokenizer = pipeline.tokenizer
# The text encoder only supplies conditioning; it is not optimised.
text_encoder.requires_grad_(False)

for epoch in range(num_train_epochs):
    pipeline.unet.train()
    for step, batch in enumerate(processed_dataset.iter(batch_size=train_batch_size)):
        with accelerator.accumulate(pipeline.unet):
            # Frozen encoders: no autograd graph is needed for these outputs.
            with torch.no_grad():
                # as_tensor accepts both torch-formatted batches and nested lists.
                pixel_values = torch.as_tensor(batch["pixel_values"]).to(
                    accelerator.device, dtype=torch.float16
                )
                latents = pipeline.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * pipeline.vae.config.scaling_factor
                # The UNet takes text-encoder hidden states, not raw strings.
                input_ids = tokenizer(
                    batch["prompt"],
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(accelerator.device)
                encoder_hidden_states = text_encoder(input_ids)[0]

            # Forward diffusion: noise each sample's latents at a random timestep.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0,
                pipeline.scheduler.config.num_train_timesteps,
                (latents.shape[0],),
                device=latents.device,
            ).long()
            noisy_latents = pipeline.scheduler.add_noise(latents, noise, timesteps)

            # The regression target depends on what the scheduler says the model predicts.
            prediction_type = pipeline.scheduler.config.prediction_type
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = pipeline.scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {prediction_type}")

            model_pred = pipeline.unet(
                noisy_latents, timesteps, encoder_hidden_states
            ).sample
            loss = torch.nn.functional.mse_loss(
                model_pred.float(), target.float(), reduction="mean"
            )
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
        progress_bar.update(1)
        # Update the Gradio components.
        train_step(step, epoch, loss.item())
progress_bar.close()

# Save the model once, after all epochs, from the main process only.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Remove accelerate's wrappers so save_pretrained sees the plain model.
    pipeline.unet = accelerator.unwrap_model(pipeline.unet)
    pipeline.save_pretrained(output_dir)

# Keep serving the Gradio dashboard after training (a no-op path if launch()
# already blocked the main thread).
interface.block_thread()