Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -46,10 +46,11 @@ class ModelWrapper:
|
|
46 |
self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
48 |
self.num_step = num_step
|
49 |
-
|
|
|
50 |
def create_generator(self, model_id, checkpoint_path):
|
51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
52 |
-
state_dict = torch.load(checkpoint_path, map_location="
|
53 |
generator.load_state_dict(state_dict, strict=True)
|
54 |
generator.requires_grad_(False)
|
55 |
return generator
|
@@ -108,7 +109,7 @@ class ModelWrapper:
|
|
108 |
eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
|
109 |
return eval_images
|
110 |
|
111 |
-
|
112 |
@torch.no_grad()
|
113 |
def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
|
114 |
print("Running model inference...")
|
@@ -196,9 +197,6 @@ def create_demo():
|
|
196 |
num_step = 4
|
197 |
revision = None
|
198 |
|
199 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
200 |
-
torch.backends.cudnn.allow_tf32 = True
|
201 |
-
|
202 |
accelerator = Accelerator()
|
203 |
|
204 |
model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)
|
@@ -211,10 +209,10 @@ def create_demo():
|
|
211 |
run_button = gr.Button("Run")
|
212 |
with gr.Accordion(label="Advanced options", open=True):
|
213 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
|
214 |
-
num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=
|
215 |
fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
|
216 |
-
height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=
|
217 |
-
width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=
|
218 |
with gr.Column():
|
219 |
result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
|
220 |
error_message = gr.Text(label="Job Status")
|
|
|
46 |
self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
48 |
self.num_step = num_step
|
49 |
+
|
50 |
+
@spaces.GPU(enable_queue=True)
|
51 |
def create_generator(self, model_id, checkpoint_path):
|
52 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
53 |
+
state_dict = torch.load(checkpoint_path, map_location="cuda")
|
54 |
generator.load_state_dict(state_dict, strict=True)
|
55 |
generator.requires_grad_(False)
|
56 |
return generator
|
|
|
109 |
eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
|
110 |
return eval_images
|
111 |
|
112 |
+
|
113 |
@torch.no_grad()
|
114 |
def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
|
115 |
print("Running model inference...")
|
|
|
197 |
num_step = 4
|
198 |
revision = None
|
199 |
|
|
|
|
|
|
|
200 |
accelerator = Accelerator()
|
201 |
|
202 |
model = ModelWrapper(model_id, checkpoint_path, precision, image_resolution, latent_resolution, num_train_timesteps, conditioning_timestep, num_step, revision, accelerator)
|
|
|
209 |
run_button = gr.Button("Run")
|
210 |
with gr.Accordion(label="Advanced options", open=True):
|
211 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=1000000, step=1, value=0)
|
212 |
+
num_images = gr.Slider(label="Number of generated images", minimum=1, maximum=16, step=1, value=1)
|
213 |
fast_vae_decode = gr.Checkbox(label="Use Tiny VAE for faster decoding", value=True)
|
214 |
+
height = gr.Slider(label="Image Height", minimum=512, maximum=1536, step=64, value=512)
|
215 |
+
width = gr.Slider(label="Image Width", minimum=512, maximum=1536, step=64, value=512)
|
216 |
with gr.Column():
|
217 |
result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
|
218 |
error_message = gr.Text(label="Job Status")
|