attention-refocusing commited on
Commit
ff2c5f2
Β·
1 Parent(s): f412da9

Update gligen/ldm/models/diffusion/plms.py

Browse files
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -151,13 +151,14 @@ class PLMSSampler(object):
151
  object_positions=object_positions, t = index1)*loss_scale
152
  loss = loss1 + loss2
153
  print('loss', loss, loss1, loss2)
154
- hh = torch.autograd.backward(loss, retain_graph=True)
155
- grad_cond = x.grad
 
156
  x = x - grad_cond
157
  x = x.detach()
158
  iteration += 1
159
 
160
- torch.cuda.empty_cache()
161
  return x
162
 
163
  def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
 
151
  object_positions=object_positions, t = index1)*loss_scale
152
  loss = loss1 + loss2
153
  print('loss', loss, loss1, loss2)
154
+ # hh = torch.autograd.backward(loss, retain_graph=True)
155
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
156
+ # grad_cond = x.grad
157
  x = x - grad_cond
158
  x = x.detach()
159
  iteration += 1
160
 
161
+
162
  return x
163
 
164
  def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):