Update app.py
Browse files
app.py
CHANGED
@@ -277,7 +277,7 @@ print('Loading custom white-background unet ...')
|
|
277 |
if os.path.exists(infer_config.unet_path):
|
278 |
unet_ckpt_path = infer_config.unet_path
|
279 |
else:
|
280 |
-
unet_ckpt_path = hf_hub_download(repo_id="LTT/
|
281 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
282 |
pipeline.unet.load_state_dict(state_dict, strict=True)
|
283 |
|
@@ -289,7 +289,7 @@ model = instantiate_from_config(model_config)
|
|
289 |
if os.path.exists(infer_config.model_path):
|
290 |
model_ckpt_path = infer_config.model_path
|
291 |
else:
|
292 |
-
model_ckpt_path = hf_hub_download(repo_id="LTT/
|
293 |
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
294 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
295 |
model.load_state_dict(state_dict, strict=True)
|
|
|
277 |
if os.path.exists(infer_config.unet_path):
|
278 |
unet_ckpt_path = infer_config.unet_path
|
279 |
else:
|
280 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
|
281 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
282 |
pipeline.unet.load_state_dict(state_dict, strict=True)
|
283 |
|
|
|
289 |
if os.path.exists(infer_config.model_path):
|
290 |
model_ckpt_path = infer_config.model_path
|
291 |
else:
|
292 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
293 |
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
294 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
295 |
model.load_state_dict(state_dict, strict=True)
|