pcuenq HF staff commited on
Commit
5e43d22
·
1 Parent(s): 2e148a1

Run model and prior in half precision.

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -118,20 +118,17 @@ def decode(img_seq, shape=(32,32)):
118
  return img
119
 
120
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
121
- state_dict = torch.load(model_path, map_location=device)
122
- model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1]).to(device)
123
- model.load_state_dict(state_dict)
124
  model.eval().requires_grad_()
125
 
126
  prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
127
- prior_ckpt = torch.load(prior_path, map_location=device)
128
- prior = PriorModel().to(device)
129
- prior.load_state_dict(prior_ckpt)
130
  prior.eval().requires_grad_(False)
131
  diffuzz = Diffuzz(device=device)
132
 
133
- del prior_ckpt, state_dict
134
-
135
  # -----
136
 
137
  def infer(prompt, negative_prompt):
 
118
  return img
119
 
120
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
121
+ model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1])
122
+ model = model.to(device).half()
123
+ model.load_state_dict(torch.load(model_path, map_location=device))
124
  model.eval().requires_grad_()
125
 
126
  prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
127
+ prior = PriorModel().to(device).half()
128
+ prior.load_state_dict(torch.load(prior_path, map_location=device))
 
129
  prior.eval().requires_grad_(False)
130
  diffuzz = Diffuzz(device=device)
131
 
 
 
132
  # -----
133
 
134
  def infer(prompt, negative_prompt):