bluestarburst
committed on
Commit
·
00e8857
1
Parent(s):
9a6a590
Upload folder using huggingface_hub
Browse files
- handler.py +8 -3
- train.py +14 -9
handler.py
CHANGED
@@ -10,6 +10,7 @@ import os
|
|
10 |
from diffusers.utils.import_utils import is_xformers_available
|
11 |
from typing import Any
|
12 |
import torch
|
|
|
13 |
import torchvision
|
14 |
import numpy as np
|
15 |
from einops import rearrange
|
@@ -101,10 +102,14 @@ class EndpointHandler():
|
|
101 |
x = (x * 255).numpy().astype(np.uint8)
|
102 |
outputs.append(x)
|
103 |
|
104 |
-
|
|
|
105 |
|
106 |
-
#
|
107 |
-
|
|
|
|
|
|
|
108 |
|
109 |
|
110 |
# This is the entry point for the serverless function.
|
|
|
10 |
from diffusers.utils.import_utils import is_xformers_available
|
11 |
from typing import Any
|
12 |
import torch
|
13 |
+
import imageio
|
14 |
import torchvision
|
15 |
import numpy as np
|
16 |
from einops import rearrange
|
|
|
102 |
x = (x * 255).numpy().astype(np.uint8)
|
103 |
outputs.append(x)
|
104 |
|
105 |
+
path = "output.gif"
|
106 |
+
imageio.mimsave(path, outputs, fps=fps)
|
107 |
|
108 |
+
# open the file as binary and read the data
|
109 |
+
with open(path, mode="rb") as file:
|
110 |
+
fileContent = file.read()
|
111 |
+
# return json response with binary data
|
112 |
+
return fileContent
|
113 |
|
114 |
|
115 |
# This is the entry point for the serverless function.
|
train.py
CHANGED
@@ -321,6 +321,7 @@ def main(
|
|
321 |
# Only show the progress bar once on each machine.
|
322 |
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
323 |
progress_bar.set_description("Steps")
|
|
|
324 |
|
325 |
for epoch in range(first_epoch, num_train_epochs):
|
326 |
unet.train()
|
@@ -363,28 +364,32 @@ def main(
|
|
363 |
else:
|
364 |
raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
|
365 |
|
|
|
366 |
# Predict the noise residual and compute loss
|
367 |
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
|
|
368 |
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
369 |
|
370 |
# Gather the losses across all processes for logging (if we use distributed training).
|
371 |
avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
|
372 |
train_loss += avg_loss.item() / gradient_accumulation_steps
|
373 |
|
374 |
-
|
375 |
-
if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
|
376 |
-
for params in module.parameters():
|
377 |
-
params.requires_grad = True
|
378 |
|
379 |
# Backpropagate
|
380 |
-
accelerator.backward(loss)
|
|
|
|
|
|
|
|
|
381 |
if accelerator.sync_gradients:
|
382 |
accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
|
383 |
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
|
|
388 |
|
389 |
optimizer.step()
|
390 |
lr_scheduler.step()
|
|
|
321 |
# Only show the progress bar once on each machine.
|
322 |
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
323 |
progress_bar.set_description("Steps")
|
324 |
+
optimizer.zero_grad()
|
325 |
|
326 |
for epoch in range(first_epoch, num_train_epochs):
|
327 |
unet.train()
|
|
|
364 |
else:
|
365 |
raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
|
366 |
|
367 |
+
|
368 |
# Predict the noise residual and compute loss
|
369 |
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
370 |
+
print("Model Output:", model_pred)
|
371 |
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
372 |
|
373 |
# Gather the losses across all processes for logging (if we use distributed training).
|
374 |
avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
|
375 |
train_loss += avg_loss.item() / gradient_accumulation_steps
|
376 |
|
377 |
+
print("Loss:", loss)
|
|
|
|
|
|
|
378 |
|
379 |
# Backpropagate
|
380 |
+
# accelerator.backward(loss)
|
381 |
+
|
382 |
+
with accelerator.scaler.scale_loss(loss) as scaled_loss:
|
383 |
+
scaled_loss.backward()
|
384 |
+
|
385 |
if accelerator.sync_gradients:
|
386 |
accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
|
387 |
|
388 |
+
print("grad: ")
|
389 |
+
for param in unet.parameters():
|
390 |
+
if param.grad is not None:
|
391 |
+
print(param.grad)
|
392 |
+
break
|
393 |
|
394 |
optimizer.step()
|
395 |
lr_scheduler.step()
|