Linoy Tsaban commited on
Commit
c0ff22d
·
1 Parent(s): fb4ae64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -158,6 +158,9 @@ def load_and_invert(
158
  # x0 = load_512(input_image, device=device).to(torch.float16)
159
 
160
  if do_inversion or randomize_seed:
 
 
 
161
  # invert and retrieve noise maps and latent
162
  zs_tensor, wts_tensor = pipe.invert(
163
  image_path = input_image,
@@ -206,6 +209,11 @@ def edit(input_image,
206
  elif(mask_type=="Intersect Mask"):
207
  use_cross_attn_mask = False
208
  use_intersect_mask = True
 
 
 
 
 
209
  if do_inversion or randomize_seed:
210
  zs_tensor, wts_tensor = pipe.invert(
211
  image_path = input_image,
@@ -259,12 +267,16 @@ def edit(input_image,
259
  return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
260
 
261
 
262
- def randomize_seed_fn(seed, randomize_seed):
263
- if randomize_seed:
264
- seed = random.randint(0, np.iinfo(np.int32).max)
265
- torch.manual_seed(seed)
266
  return seed
267
 
 
 
 
 
 
 
268
  def crop_image(image):
269
  h, w, c = image.shape
270
  if h < w:
 
158
  # x0 = load_512(input_image, device=device).to(torch.float16)
159
 
160
  if do_inversion or randomize_seed:
161
+ if randomize_seed:
162
+ seed = randomize_seed_fn()
163
+ seed_everything(seed)
164
  # invert and retrieve noise maps and latent
165
  zs_tensor, wts_tensor = pipe.invert(
166
  image_path = input_image,
 
209
  elif(mask_type=="Intersect Mask"):
210
  use_cross_attn_mask = False
211
  use_intersect_mask = True
212
+
213
+ if randomize_seed:
214
+ seed = randomize_seed_fn()
215
+ seed_everything(seed)
216
+
217
  if do_inversion or randomize_seed:
218
  zs_tensor, wts_tensor = pipe.invert(
219
  image_path = input_image,
 
267
  return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
268
 
269
 
270
+ def randomize_seed_fn():
271
+ seed = random.randint(0, np.iinfo(np.int32).max)
 
 
272
  return seed
273
 
274
+ def seed_everything(seed):
275
+ torch.manual_seed(seed)
276
+ torch.cuda.manual_seed(seed)
277
+ random.seed(seed)
278
+ np.random.seed(seed)
279
+
280
  def crop_image(image):
281
  h, w, c = image.shape
282
  if h < w: