abreza commited on
Commit
4e92ab0
·
1 Parent(s): 375ee53

refactor: Update image generation pipeline to use playground v2.5

Browse files
Files changed (2) hide show
  1. launch/image_generation.py +7 -19
  2. launch/utils.py +2 -1
launch/image_generation.py CHANGED
@@ -4,19 +4,16 @@ import gradio as gr
4
  import rembg
5
  import spaces
6
  import torch
7
- from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
8
- from huggingface_hub import hf_hub_download
9
- from safetensors.torch import load_file
10
 
11
  from src.utils.infer_util import (remove_background, resize_foreground)
12
 
13
 
14
- # Load StableDiffusionXL model
15
- base = "stabilityai/stable-diffusion-xl-base-1.0"
16
- repo = "ByteDance/SDXL-Lightning"
17
-
18
- pipe = StableDiffusionXLPipeline.from_pretrained(
19
- base, torch_dtype=torch.float16, variant="fp16").to("cuda")
20
 
21
 
22
  def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
@@ -25,18 +22,9 @@ def generate_prompt(subject, style, color_scheme, angle, lighting_type, addition
25
 
26
  @spaces.GPU
27
  def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
28
- checkpoint = "sdxl_lightning_8step_unet.safetensors"
29
- num_inference_steps = 8
30
-
31
- pipe.scheduler = EulerDiscreteScheduler.from_config(
32
- pipe.scheduler.config, timestep_spacing="trailing")
33
- pipe.unet.load_state_dict(
34
- load_file(hf_hub_download(repo, checkpoint), device="cuda"))
35
-
36
  prompt = generate_prompt(subject, style, color_scheme,
37
  angle, lighting_type, additional_details)
38
- results = pipe(
39
- prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
40
  return results.images[0]
41
 
42
 
 
4
  import rembg
5
  import spaces
6
  import torch
7
+ from diffusers import DiffusionPipeline
 
 
8
 
9
  from src.utils.infer_util import (remove_background, resize_foreground)
10
 
11
 
12
+ pipe = DiffusionPipeline.from_pretrained(
13
+ "playgroundai/playground-v2.5-1024px-aesthetic",
14
+ torch_dtype=torch.float16,
15
+ variant="fp16"
16
+ ).to("cuda")
 
17
 
18
 
19
  def generate_prompt(subject, style, color_scheme, angle, lighting_type, additional_details):
 
22
 
23
  @spaces.GPU
24
  def generate_image(subject, style, color_scheme, angle, lighting_type, additional_details):
 
 
 
 
 
 
 
 
25
  prompt = generate_prompt(subject, style, color_scheme,
26
  angle, lighting_type, additional_details)
27
+ results = pipe(prompt, num_inference_steps=25, guidance_scale=7.5)
 
28
  return results.images[0]
29
 
30
 
launch/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import shutil
3
 
 
4
  def find_cuda():
5
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
6
  if cuda_home and os.path.exists(cuda_home):
@@ -11,4 +12,4 @@ def find_cuda():
11
  cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
12
  return cuda_path
13
 
14
- return None
 
1
  import os
2
  import shutil
3
 
4
+
5
  def find_cuda():
6
  cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
7
  if cuda_home and os.path.exists(cuda_home):
 
12
  cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
13
  return cuda_path
14
 
15
+ return None