aimersion commited on
Commit
35d32f6
1 Parent(s): 9583fce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -14,11 +14,11 @@ except ImportError:
14
  raise ImportError("The 'sentencepiece' library is required but not installed. Please add it to your environment.")
15
 
16
  # Set the device and dtype
17
- dtype = torch.bfloat16
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # Load the diffusion pipeline without requiring an API token
21
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 2048
@@ -32,7 +32,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
32
 
33
  if randomize_seed:
34
  seed = random.randint(0, MAX_SEED)
35
- generator = torch.Generator().manual_seed(seed)
36
 
37
  try:
38
  image = pipe(
 
14
  raise ImportError("The 'sentencepiece' library is required but not installed. Please add it to your environment.")
15
 
16
  # Set the device and dtype
17
+ dtype = torch.float16 # Change to float16 for better compatibility and performance
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # Load the diffusion pipeline without requiring an API token
21
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 2048
 
32
 
33
  if randomize_seed:
34
  seed = random.randint(0, MAX_SEED)
35
+ generator = torch.Generator(device=device).manual_seed(seed) # Ensure generator is on the correct device
36
 
37
  try:
38
  image = pipe(