aimersion commited on
Commit
80a2167
1 Parent(s): 46b8903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -3,18 +3,39 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
 
 
6
  from diffusers import DiffusionPipeline
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  dtype = torch.bfloat16
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
12
 
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
15
 
16
  @spaces.GPU()
17
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, guidance_scale=7.5, progress=gr.Progress(track_tqdm=True)):
 
 
18
  if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
19
  raise ValueError("Image size exceeds the maximum allowed dimensions.")
20
 
@@ -24,23 +45,27 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
24
 
25
  try:
26
  image = pipe(
27
- prompt=prompt,
28
- width=width,
29
- height=height,
30
- num_inference_steps=num_inference_steps,
31
- generator=generator,
32
- guidance_scale=guidance_scale
33
- ).images[0]
34
  except Exception as e:
 
35
  return None, seed, f"Error: {str(e)}"
36
 
 
 
 
 
37
  return image, seed, None
38
 
39
  examples = [
40
  "a tiny astronaut hatching from an egg on the moon",
41
  "a cat holding a sign that says hello world",
42
  "an anime illustration of a wiener schnitzel",
43
- # Add more diverse examples
44
  ]
45
 
46
  css = """
@@ -55,7 +80,7 @@ with gr.Blocks(css=css) as demo:
55
  with gr.Column(elem_id="col-container"):
56
  gr.Markdown(f"""# Custom Image Creator
57
  12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
58
- [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
59
  """)
60
 
61
  with gr.Row():
 
3
  import random
4
  import spaces
5
  import torch
6
+ import os
7
+ import time
8
  from diffusers import DiffusionPipeline
9
+ from huggingface_hub import login
10
 
11
+ # Ensure sentencepiece is installed in your environment
12
+ try:
13
+ import sentencepiece
14
+ except ImportError:
15
+ raise ImportError("The 'sentencepiece' library is required but not installed. Please add it to your environment.")
16
+
17
+ # Access the API token securely from Hugging Face Secrets
18
+ hf_api_token = os.getenv("HF_API_TOKEN")
19
+
20
+ if hf_api_token:
21
+ login(token=hf_api_token)
22
+ else:
23
+ raise ValueError("Hugging Face API token not found in secrets.")
24
+
25
+ # Set the device and dtype
26
  dtype = torch.bfloat16
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
+ # Load the diffusion pipeline from the gated repository
30
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
  MAX_IMAGE_SIZE = 2048
34
 
35
  @spaces.GPU()
36
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, guidance_scale=7.5, progress=gr.Progress(track_tqdm=True)):
37
+ start_time = time.time()
38
+
39
  if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
40
  raise ValueError("Image size exceeds the maximum allowed dimensions.")
41
 
 
45
 
46
  try:
47
  image = pipe(
48
+ prompt=prompt,
49
+ width=width,
50
+ height=height,
51
+ num_inference_steps=num_inference_steps,
52
+ generator=generator,
53
+ guidance_scale=guidance_scale
54
+ ).images[0]
55
  except Exception as e:
56
+ print(f"Error generating image: {e}")
57
  return None, seed, f"Error: {str(e)}"
58
 
59
+ # Check if it took too long
60
+ if time.time() - start_time > 60: # 60 seconds timeout
61
+ return None, seed, "Image generation took too long and was cancelled."
62
+
63
  return image, seed, None
64
 
65
  examples = [
66
  "a tiny astronaut hatching from an egg on the moon",
67
  "a cat holding a sign that says hello world",
68
  "an anime illustration of a wiener schnitzel",
 
69
  ]
70
 
71
  css = """
 
80
  with gr.Column(elem_id="col-container"):
81
  gr.Markdown(f"""# Custom Image Creator
82
  12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
83
+ [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1)]
84
  """)
85
 
86
  with gr.Row():