Cun-Duck commited on
Commit
6d1d75b
·
verified ·
1 Parent(s): bc5e1f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -22
app.py CHANGED
@@ -2,21 +2,24 @@ import torch
2
  from diffusers import DiffusionPipeline, DDPMScheduler
3
  from accelerate import Accelerator
4
  from datasets import load_dataset
5
- from huggingface_hub import HfFolder, Repository, whoami
6
  from tqdm.auto import tqdm
7
  from transformers import TrainingArguments
 
8
 
9
  # Konfigurasi
10
  pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
11
- dataset_name = "DucHaiten/anime-SDXL"
12
  learning_rate = 1e-5
13
- num_train_epochs = 3
14
- train_batch_size = 4
15
- gradient_accumulation_steps = 2
16
- output_dir = "flux-anime"
 
17
 
18
  # Muat model dan scheduler
19
- pipeline = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
 
 
20
  pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
21
  pipeline.enable_xformers_memory_efficient_attention()
22
 
@@ -25,9 +28,14 @@ dataset = load_dataset(dataset_name)["train"]
25
 
26
  # Fungsi untuk memproses data
27
  def preprocess_function(examples):
28
- images = [image.convert("RGB") for image in examples["image"]]
 
 
 
29
  texts = [text for text in examples["text"]]
30
- examples["pixel_values"] = pipeline.feature_extractor(images=images, return_tensors="pt").pixel_values
 
 
31
  examples["prompt"] = texts
32
  return examples
33
 
@@ -62,34 +70,82 @@ training_args = TrainingArguments(
62
  num_train_epochs=num_train_epochs,
63
  learning_rate=learning_rate,
64
  fp16=True,
65
- logging_dir="./logs", # Direktori untuk menyimpan log TensorBoard
66
- report_to="tensorboard" # Aktifkan logging ke TensorBoard
 
67
  )
68
 
69
  # Training loop
70
- progress_bar = tqdm(range(num_train_epochs * len(processed_dataset) // train_batch_size))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  for epoch in range(num_train_epochs):
72
  pipeline.unet.train()
73
- for step, batch in enumerate(processed_dataset.iter(batch_size=train_batch_size)):
 
 
74
  with accelerator.accumulate(pipeline.unet):
75
- latents = pipeline.vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
 
 
76
  latents = latents * pipeline.vae.config.scaling_factor
77
  noise = torch.randn_like(latents)
78
  bsz = latents.shape[0]
79
- timesteps = torch.randint(0, pipeline.scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
 
 
 
 
 
80
  timesteps = timesteps.long()
81
- noisy_latents = pipeline.scheduler.add_noise(latents, noise, timesteps)
82
-
83
- # Tidak ada text encoder di FLUX
84
- model_pred = pipeline.unet(noisy_latents, timesteps, batch["prompt"]).sample
85
-
86
- loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
 
 
 
 
 
87
  accelerator.backward(loss)
88
  optimizer.step()
89
  optimizer.zero_grad()
90
  progress_bar.update(1)
91
 
 
 
 
92
  # Simpan model
93
  pipeline.save_pretrained(output_dir)
94
- repo.push_to_hub(commit_message=f"Training epoch {epoch}", blocking=False)
95
 
 
2
  from diffusers import DiffusionPipeline, DDPMScheduler
3
  from accelerate import Accelerator
4
  from datasets import load_dataset
 
5
  from tqdm.auto import tqdm
6
  from transformers import TrainingArguments
7
+ import gradio as gr
8
 
9
  # Konfigurasi
10
  pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
11
+ dataset_name = "DucHaiten/anime-SDXL" # Gunakan dataset sesuai keinginan Anda
12
  learning_rate = 1e-5
13
+ num_train_epochs = 2 # Sesuaikan dengan kebutuhan
14
+ train_batch_size = 1 # Gunakan batch size kecil untuk Spaces gratis
15
+ gradient_accumulation_steps = 4 # Sesuaikan dengan kebutuhan
16
+ output_dir = "flux-anime"
17
+ image_resize = 128 # Sesuaikan dengan kebutuhan
18
 
19
  # Muat model dan scheduler
20
+ pipeline = DiffusionPipeline.from_pretrained(
21
+ pretrained_model_name_or_path, torch_dtype=torch.float16
22
+ )
23
  pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
24
  pipeline.enable_xformers_memory_efficient_attention()
25
 
 
28
 
29
  # Fungsi untuk memproses data
30
  def preprocess_function(examples):
31
+ images = [
32
+ image.convert("RGB").resize((image_resize, image_resize))
33
+ for image in examples["image"]
34
+ ]
35
  texts = [text for text in examples["text"]]
36
+ examples["pixel_values"] = pipeline.feature_extractor(
37
+ images=images, return_tensors="pt"
38
+ ).pixel_values
39
  examples["prompt"] = texts
40
  return examples
41
 
 
70
  num_train_epochs=num_train_epochs,
71
  learning_rate=learning_rate,
72
  fp16=True,
73
+ logging_dir="./logs",
74
+ report_to="tensorboard",
75
+ push_to_hub=True, # Push model ke Hugging Face Hub
76
  )
77
 
78
  # Training loop
79
+ progress_bar = tqdm(
80
+ range(num_train_epochs * len(processed_dataset) // train_batch_size)
81
+ )
82
+
83
+ # --- Komponen Gradio ---
84
+ with gr.Blocks() as interface:
85
+ gr.Markdown(
86
+ "## Fine-tuning FLUX untuk Anime"
87
+ ) # Ganti judul sesuai dataset Anda
88
+ loss_textbox = gr.Textbox(label="Loss")
89
+ epoch_textbox = gr.Textbox(label="Epoch")
90
+ progress_bar_gradio = gr.ProgressBar(label="Progress")
91
+ output_image = gr.Image(label="Generated Image")
92
+
93
+ def train_step(step, epoch, loss):
94
+ loss_textbox.update(value=loss)
95
+ epoch_textbox.update(value=epoch)
96
+ progress_bar_gradio.update(value=step / len(progress_bar))
97
+ if step % 100 == 0:
98
+ with torch.no_grad():
99
+ image = pipeline(
100
+ "anime style image of a girl with blue hair"
101
+ ).images[
102
+ 0
103
+ ] # Ganti prompt sesuai dataset Anda
104
+ output_image.update(value=image)
105
+ return loss, epoch, step / len(progress_bar)
106
+
107
+ interface.launch(server_name="0.0.0.0")
108
+
109
+ # ------------------------
110
+
111
  for epoch in range(num_train_epochs):
112
  pipeline.unet.train()
113
+ for step, batch in enumerate(
114
+ processed_dataset.iter(batch_size=train_batch_size)
115
+ ):
116
  with accelerator.accumulate(pipeline.unet):
117
+ latents = pipeline.vae.encode(
118
+ batch["pixel_values"].to(dtype=torch.float16)
119
+ ).latent_dist.sample()
120
  latents = latents * pipeline.vae.config.scaling_factor
121
  noise = torch.randn_like(latents)
122
  bsz = latents.shape[0]
123
+ timesteps = torch.randint(
124
+ 0,
125
+ pipeline.scheduler.config.num_train_timesteps,
126
+ (bsz,),
127
+ device=latents.device,
128
+ )
129
  timesteps = timesteps.long()
130
+ noisy_latents = pipeline.scheduler.add_noise(
131
+ latents, noise, timesteps
132
+ )
133
+
134
+ model_pred = pipeline.unet(
135
+ noisy_latents, timesteps, batch["prompt"]
136
+ ).sample
137
+
138
+ loss = torch.nn.functional.mse_loss(
139
+ model_pred.float(), noise.float(), reduction="mean"
140
+ )
141
  accelerator.backward(loss)
142
  optimizer.step()
143
  optimizer.zero_grad()
144
  progress_bar.update(1)
145
 
146
+ # Update komponen Gradio
147
+ train_step(step, epoch, loss.item())
148
+
149
  # Simpan model
150
  pipeline.save_pretrained(output_dir)
 
151