ginipick commited on
Commit
cf46674
โ€ข
1 Parent(s): 30794f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -12,10 +12,13 @@ import torch
12
  from diffusers import FluxPipeline
13
  from PIL import Image
14
 
 
 
 
 
15
 
16
  # Setup and initialization code
17
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
18
- # Use PERSISTENT_DIR environment variable for Spaces
19
  PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
20
  gallery_path = path.join(PERSISTENT_DIR, "gallery")
21
 
@@ -43,16 +46,27 @@ class timer:
43
  if not path.exists(cache_path):
44
  os.makedirs(cache_path, exist_ok=True)
45
 
46
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
47
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  pipe.fuse_lora(lora_scale=0.125)
49
  pipe.to(device="cuda", dtype=torch.bfloat16)
50
 
51
-
52
  def save_image(image):
53
  """Save the generated image and return the path"""
54
  try:
55
- # Ensure gallery directory exists
56
  if not os.path.exists(gallery_path):
57
  try:
58
  os.makedirs(gallery_path, exist_ok=True)
@@ -60,7 +74,6 @@ def save_image(image):
60
  print(f"Failed to create gallery directory: {str(e)}")
61
  return None
62
 
63
- # Generate unique filename with timestamp and random suffix
64
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
65
  random_suffix = os.urandom(4).hex()
66
  filename = f"generated_{timestamp}_{random_suffix}.png"
@@ -86,24 +99,19 @@ def save_image(image):
86
  print(f"Error in save_image: {str(e)}")
87
  return None
88
 
89
-
90
  def load_gallery():
91
  """Load all images from the gallery directory"""
92
  try:
93
- # Ensure gallery directory exists
94
  os.makedirs(gallery_path, exist_ok=True)
95
 
96
- # Get all image files and sort by modification time
97
  image_files = []
98
  for f in os.listdir(gallery_path):
99
  if f.lower().endswith(('.png', '.jpg', '.jpeg')):
100
  full_path = os.path.join(gallery_path, f)
101
  image_files.append((full_path, os.path.getmtime(full_path)))
102
 
103
- # Sort by modification time (newest first)
104
  image_files.sort(key=lambda x: x[1], reverse=True)
105
 
106
- # Return only the file paths
107
  return [f[0] for f in image_files]
108
  except Exception as e:
109
  print(f"Error loading gallery: {str(e)}")
@@ -111,7 +119,6 @@ def load_gallery():
111
 
112
  # Create Gradio interface
113
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
114
-
115
  with gr.Row():
116
  with gr.Column(scale=3):
117
  prompt = gr.Textbox(
@@ -168,18 +175,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
168
  "โœจ Generate Image",
169
  elem_classes=["generate-btn"]
170
  )
171
-
172
-
173
 
174
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
175
- # Current generated image
176
  output = gr.Image(
177
  label="Generated Image",
178
  elem_id="output-image",
179
  elem_classes=["output-image", "fixed-width"]
180
  )
181
 
182
- # Gallery of generated images
183
  gallery = gr.Gallery(
184
  label="Generated Images Gallery",
185
  show_label=True,
@@ -191,7 +194,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
191
  elem_classes=["gallery-container", "fixed-width"]
192
  )
193
 
194
- # Load existing gallery images on startup
195
  gallery.value = load_gallery()
196
 
197
  @spaces.GPU
@@ -209,18 +211,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
209
  max_sequence_length=256
210
  ).images[0]
211
 
212
- # Save the generated image
213
  saved_path = save_image(generated_image)
214
  if saved_path is None:
215
  print("Warning: Failed to save generated image")
216
 
217
- # Return both the generated image and updated gallery
218
  return generated_image, load_gallery()
219
  except Exception as e:
220
  print(f"Error in image generation: {str(e)}")
221
  return None, load_gallery()
222
 
223
- # Connect the generation button to both the image output and gallery update
224
  def update_seed():
225
  return get_random_seed()
226
 
@@ -230,13 +229,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
230
  outputs=[output, gallery]
231
  )
232
 
233
- # Add randomize seed button functionality
234
  randomize_seed.click(
235
  update_seed,
236
  outputs=[seed]
237
  )
238
 
239
- # Also randomize seed after each generation
240
  generate_btn.click(
241
  update_seed,
242
  outputs=[seed]
 
12
  from diffusers import FluxPipeline
13
  from PIL import Image
14
 
15
+ # Hugging Face ํ† ํฐ ์„ค์ •
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+ if HF_TOKEN is None:
18
+ raise ValueError("HF_TOKEN environment variable is not set")
19
 
20
  # Setup and initialization code
21
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 
22
  PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
23
  gallery_path = path.join(PERSISTENT_DIR, "gallery")
24
 
 
46
  if not path.exists(cache_path):
47
  os.makedirs(cache_path, exist_ok=True)
48
 
49
+ # ์ธ์ฆ๋œ ๋ชจ๋ธ ๋กœ๋“œ
50
+ pipe = FluxPipeline.from_pretrained(
51
+ "black-forest-labs/FLUX.1-dev",
52
+ torch_dtype=torch.bfloat16,
53
+ use_auth_token=HF_TOKEN
54
+ )
55
+
56
+ # Hyper-SD LoRA ๋กœ๋“œ (์ธ์ฆ ํฌํ•จ)
57
+ pipe.load_lora_weights(
58
+ hf_hub_download(
59
+ "ByteDance/Hyper-SD",
60
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors",
61
+ use_auth_token=HF_TOKEN
62
+ )
63
+ )
64
  pipe.fuse_lora(lora_scale=0.125)
65
  pipe.to(device="cuda", dtype=torch.bfloat16)
66
 
 
67
  def save_image(image):
68
  """Save the generated image and return the path"""
69
  try:
 
70
  if not os.path.exists(gallery_path):
71
  try:
72
  os.makedirs(gallery_path, exist_ok=True)
 
74
  print(f"Failed to create gallery directory: {str(e)}")
75
  return None
76
 
 
77
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
78
  random_suffix = os.urandom(4).hex()
79
  filename = f"generated_{timestamp}_{random_suffix}.png"
 
99
  print(f"Error in save_image: {str(e)}")
100
  return None
101
 
 
102
  def load_gallery():
103
  """Load all images from the gallery directory"""
104
  try:
 
105
  os.makedirs(gallery_path, exist_ok=True)
106
 
 
107
  image_files = []
108
  for f in os.listdir(gallery_path):
109
  if f.lower().endswith(('.png', '.jpg', '.jpeg')):
110
  full_path = os.path.join(gallery_path, f)
111
  image_files.append((full_path, os.path.getmtime(full_path)))
112
 
 
113
  image_files.sort(key=lambda x: x[1], reverse=True)
114
 
 
115
  return [f[0] for f in image_files]
116
  except Exception as e:
117
  print(f"Error loading gallery: {str(e)}")
 
119
 
120
  # Create Gradio interface
121
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
122
  with gr.Row():
123
  with gr.Column(scale=3):
124
  prompt = gr.Textbox(
 
175
  "โœจ Generate Image",
176
  elem_classes=["generate-btn"]
177
  )
 
 
178
 
179
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
 
180
  output = gr.Image(
181
  label="Generated Image",
182
  elem_id="output-image",
183
  elem_classes=["output-image", "fixed-width"]
184
  )
185
 
 
186
  gallery = gr.Gallery(
187
  label="Generated Images Gallery",
188
  show_label=True,
 
194
  elem_classes=["gallery-container", "fixed-width"]
195
  )
196
 
 
197
  gallery.value = load_gallery()
198
 
199
  @spaces.GPU
 
211
  max_sequence_length=256
212
  ).images[0]
213
 
 
214
  saved_path = save_image(generated_image)
215
  if saved_path is None:
216
  print("Warning: Failed to save generated image")
217
 
 
218
  return generated_image, load_gallery()
219
  except Exception as e:
220
  print(f"Error in image generation: {str(e)}")
221
  return None, load_gallery()
222
 
 
223
  def update_seed():
224
  return get_random_seed()
225
 
 
229
  outputs=[output, gallery]
230
  )
231
 
 
232
  randomize_seed.click(
233
  update_seed,
234
  outputs=[seed]
235
  )
236
 
 
237
  generate_btn.click(
238
  update_seed,
239
  outputs=[seed]