# ayobelajar / app.py
# Hugging Face file-viewer header (not part of the program):
# last commit "Update app.py" (6d1d75b, verified) by Cun-Duck; file size 4.8 kB.
import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from accelerate import Accelerator
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import TrainingArguments
import gradio as gr
# --- Configuration ---

# Base checkpoint to fine-tune and the dataset to fine-tune it on
# (use whichever dataset you prefer).
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
dataset_name = "DucHaiten/anime-SDXL"

# Directory the fine-tuned pipeline is saved to.
output_dir = "flux-anime"

# Training images are resized to image_resize x image_resize; adjust as needed.
image_resize = 128

# Optimisation settings; adjust as needed.
num_train_epochs = 2
learning_rate = 1e-5
# Keep the batch size small for free Spaces hardware.
train_batch_size = 1
gradient_accumulation_steps = 4
# Load the pretrained pipeline with half-precision weights.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float16,
)

# Replace the pipeline's scheduler with a DDPM scheduler built from the
# existing scheduler's config; the training loop calls its add_noise().
# NOTE(review): the checkpoint's own scheduler may be of a different family
# than DDPM -- confirm that this swap is compatible with pipeline inference.
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

# Memory-efficient attention is only an optimisation, so do not abort start-up
# when it cannot be enabled (e.g. xformers is not installed or no CUDA device
# is present -- TODO confirm the exception types for the installed diffusers).
try:
    pipeline.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError) as err:
    print(
        "xformers memory-efficient attention is unavailable; "
        f"continuing without it: {err}"
    )

# Load the "train" split of the dataset.
dataset = load_dataset(dataset_name)["train"]
def preprocess_function(examples):
    """Convert a batch of raw dataset rows into training inputs.

    Args:
        examples: Batched mapping of column name to list of values. The
            "image" column must hold objects offering ``convert`` and
            ``resize`` (presumably PIL images -- confirm against the dataset)
            and the "text" column holds the matching captions.

    Returns:
        The same mapping, extended with "pixel_values" (the processed image
        batch) and "prompt" (a list copy of the captions).
    """
    images = [
        image.convert("RGB").resize((image_resize, image_resize))
        for image in examples["image"]
    ]
    # Use the pipeline's VAE image processor rather than
    # ``pipeline.feature_extractor``: the feature extractor is an optional
    # component (unset for the FLUX pipeline) and, where present, prepares
    # inputs for an image encoder rather than for the VAE.
    # NOTE(review): assumes ``image_processor.preprocess`` returns a tensor
    # batch scaled for the VAE -- confirm for the installed diffusers version.
    examples["pixel_values"] = pipeline.image_processor.preprocess(images)
    examples["prompt"] = list(examples["text"])
    return examples
# Run the preprocessing over the whole dataset in batches using four worker
# processes, dropping every original column from the result.
# NOTE(review): preprocess_function reads the module-level pipeline; check
# that this works as intended with multi-process mapping.
processed_dataset = dataset.map(
    function=preprocess_function,
    remove_columns=dataset.column_names,
    num_proc=4,
    batched=True,
)
# Initialise the accelerator with gradient accumulation and fp16 autocast.
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="fp16",
)

# fp16 mixed precision expects the trainable weights themselves to be float32
# (the gradient scaler refuses to unscale fp16 gradients), so upcast the model
# being trained before handing its parameters to the optimizer. This roughly
# doubles the memory used by those weights.
# NOTE(review): FLUX pipelines may expose the denoiser as ``transformer``
# rather than ``unet`` -- confirm the attribute name for this checkpoint.
pipeline.unet.to(dtype=torch.float32)

# Optimizer over the trainable model's parameters.
optimizer = torch.optim.AdamW(
    pipeline.unet.parameters(),
    lr=learning_rate,
)

# The optimizer must be prepared together with the model: only a prepared
# optimizer skips its step on non-sync micro-batches inside
# ``accelerator.accumulate`` and unscales fp16 gradients before stepping.
# NOTE(review): processed_dataset is not a DataLoader; presumably prepare()
# hands it back unchanged -- confirm.
pipeline.unet, pipeline.vae, optimizer, processed_dataset = accelerator.prepare(
    pipeline.unet, pipeline.vae, optimizer, processed_dataset
)
# Hugging Face TrainingArguments mirroring the settings above.
# NOTE(review): training_args is not referenced again in the visible code;
# the loop below is driven by accelerate directly.
training_args = TrainingArguments(
    # Output and logging locations.
    output_dir=output_dir,
    logging_dir="./logs",
    report_to="tensorboard",
    # Schedule and optimisation.
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    fp16=True,
    # Push the model to the Hugging Face Hub.
    push_to_hub=True,
)
# Console progress bar sized to the total number of training steps.
# The training loop iterates the dataset in batches and the final, possibly
# smaller, batch of each epoch is still yielded, so the per-epoch step count
# is the ceiling of len(dataset) / batch_size (written with integer maths).
progress_bar = tqdm(
    range(
        num_train_epochs
        * ((len(processed_dataset) + train_batch_size - 1) // train_batch_size)
    )
)
# --- Gradio components ---
# Latest training status. train_step() writes it from the training loop and
# the UI's "Refresh" handler reads it. Calling ``component.update(...)``
# outside an event handler never reaches the browser, so the values are
# published through this shared dict instead.
_training_status = {"loss": "", "epoch": "", "progress": 0.0, "image": None}


def _read_training_status():
    """Return the latest (loss, epoch, progress, image) for the UI outputs."""
    return (
        _training_status["loss"],
        _training_status["epoch"],
        _training_status["progress"],
        _training_status["image"],
    )


with gr.Blocks() as interface:
    gr.Markdown(
        "## Fine-tuning FLUX untuk Anime"
    )  # Change the title to match your dataset
    loss_textbox = gr.Textbox(label="Loss")
    epoch_textbox = gr.Textbox(label="Epoch")
    # Gradio has no ``ProgressBar`` component; a read-only 0..1 slider shows
    # the completed fraction instead.
    progress_bar_gradio = gr.Slider(
        minimum=0.0, maximum=1.0, label="Progress", interactive=False
    )
    output_image = gr.Image(label="Generated Image")
    refresh_button = gr.Button("Refresh")
    refresh_button.click(
        fn=_read_training_status,
        inputs=None,
        outputs=[
            loss_textbox,
            epoch_textbox,
            progress_bar_gradio,
            output_image,
        ],
    )


def train_step(step, epoch, loss):
    """Record the status of one training step for display in the UI.

    Every 100th step (including step 0) a sample image is also generated
    with the current pipeline and stored for the UI.

    Args:
        step: Step index passed in by the training loop.
        epoch: Current epoch index.
        loss: Loss value of the step.

    Returns:
        A ``(loss, epoch, progress)`` tuple, where ``progress`` is
        ``step / len(progress_bar)`` (0.0 when the bar has no steps).
    """
    total_steps = len(progress_bar)
    # Guard against an empty dataset, where the bar has zero steps.
    progress = step / total_steps if total_steps else 0.0
    _training_status.update(loss=loss, epoch=epoch, progress=progress)
    if step % 100 == 0:
        with torch.no_grad():
            image = pipeline(
                "anime style image of a girl with blue hair"
            ).images[
                0
            ]  # Change the prompt to match your dataset
        _training_status["image"] = image
    return loss, epoch, progress


# Do not block the main thread: the training loop that follows must run while
# the UI is being served. Note that with prevent_thread_lock=True the server
# shuts down once the script finishes.
interface.launch(server_name="0.0.0.0", prevent_thread_lock=True)
# ------------------------
# Training loop: encode images to latents, add noise at a random timestep,
# predict that noise and regress the prediction onto it with an MSE loss.
# NOTE(review): FLUX pipelines may expose the denoiser as ``transformer``
# rather than ``unet`` -- confirm the attribute name for this checkpoint.
for epoch in range(num_train_epochs):
    pipeline.unet.train()
    for step, batch in enumerate(
        processed_dataset.iter(batch_size=train_batch_size)
    ):
        with accelerator.accumulate(pipeline.unet):
            # Unless a tensor format was set on the dataset, iter() yields
            # plain Python lists, which have no ``.to``; as_tensor accepts
            # lists, arrays and tensors alike and places the batch on the
            # accelerator's device.
            pixel_values = torch.as_tensor(
                batch["pixel_values"],
                dtype=torch.float16,
                device=accelerator.device,
            )
            # The VAE is not optimised here, so skip building its graph.
            with torch.no_grad():
                latents = pipeline.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * pipeline.vae.config.scaling_factor

            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # One random timestep per sample (randint already yields int64).
            timesteps = torch.randint(
                0,
                pipeline.scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            noisy_latents = pipeline.scheduler.add_noise(
                latents, noise, timesteps
            )

            # NOTE(review): the prompts are passed as raw strings; the model
            # presumably expects encoded text embeddings -- confirm.
            model_pred = pipeline.unet(
                noisy_latents, timesteps, batch["prompt"]
            ).sample
            loss = torch.nn.functional.mse_loss(
                model_pred.float(), noise.float(), reduction="mean"
            )

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        progress_bar.update(1)
        # Update the Gradio components.
        train_step(step, epoch, loss.item())

# Save the model.
pipeline.save_pretrained(output_dir)