Jordan Legg commited on
Commit
4153232
1 Parent(s): fa4f33d

tried to fix seeds in local optima

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -4,13 +4,12 @@ import numpy as np
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
- from numpy.random import PCG64DXSM, Generator # Add Generator import
8
  from typing import Tuple, Any
9
 
10
  dtype: torch.dtype = torch.bfloat16
11
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
12
- MAX_SEED = np.iinfo(np.int32).max
13
- rng = Generator(PCG64DXSM()) # Create a Generator instance instead of using PCG64DXSM directly
14
 
15
  pipe = DiffusionPipeline.from_pretrained("shuttleai/shuttle-3-diffusion", torch_dtype=dtype).to(device)
16
  # Enable VAE tiling
@@ -61,6 +60,13 @@ def validate_aspect_ratio(ratio_name: str) -> float | None:
61
  case _:
62
  return None
63
 
 
 
 
 
 
 
 
64
  @spaces.GPU()
65
  def infer(
66
  prompt: str,
@@ -75,7 +81,7 @@ def infer(
75
  FULL_PROMPT = f"{STYLE_PROMPT} {prompt}"
76
 
77
  if randomize_seed:
78
- seed = int(rng.integers(0, MAX_SEED))
79
 
80
  ratio = validate_aspect_ratio(aspect_ratio)
81
  if ratio is None:
 
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
+ from numpy.random import PCG64DXSM, Generator, SeedSequence
8
  from typing import Tuple, Any
9
 
10
  dtype: torch.dtype = torch.bfloat16
11
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
12
+ MAX_SEED = np.iinfo(np.int64).max
 
13
 
14
  pipe = DiffusionPipeline.from_pretrained("shuttleai/shuttle-3-diffusion", torch_dtype=dtype).to(device)
15
  # Enable VAE tiling
 
60
  case _:
61
  return None
62
 
63
+ # Replace the single rng instance with a function that creates a fresh generator each time
64
+ def get_random_seed() -> int:
65
+ # Create a new generator with a random seed each time
66
+ ss = SeedSequence()
67
+ rng = Generator(PCG64DXSM(ss))
68
+ return int(rng.integers(0, MAX_SEED))
69
+
70
  @spaces.GPU()
71
  def infer(
72
  prompt: str,
 
81
  FULL_PROMPT = f"{STYLE_PROMPT} {prompt}"
82
 
83
  if randomize_seed:
84
+ seed = get_random_seed()
85
 
86
  ratio = validate_aspect_ratio(aspect_ratio)
87
  if ratio is None: