radames committed
Commit 7d54ba7 · 1 Parent(s): 96b49cf
Files changed (1): app.py (+21 -50)
app.py CHANGED
@@ -1,10 +1,7 @@
 import numpy as np
 import PIL.Image
 import torch
-from typing import List
-from diffusers.utils import numpy_to_pil
-from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
-from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+from diffusers import LCMScheduler, AutoPipelineForText2Image
 from fastapi import FastAPI
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
@@ -20,11 +17,8 @@ from fastapi.middleware.cors import CORSMiddleware
 
 logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
 
-MAX_SEED = np.iinfo(np.int32).max
-USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
 SPACE_ID = os.environ.get("SPACE_ID", "")
 DEV = os.environ.get("DEV", "0") == "1"
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
 DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
 IMGS_PATH = DB_PATH / "imgs"
@@ -33,64 +27,39 @@ IMGS_PATH.mkdir(exist_ok=True, parents=True)
 
 database = Database(DB_PATH)
 
+
+model_id = "segmind/Segmind-Vega"
+adapter_id = "segmind/Segmind-VegaRT"
+
 dtype = torch.bfloat16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
-    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
-        "stabilityai/stable-cascade-prior", torch_dtype=dtype
-    )  # .to(device)
-    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
-        "stabilityai/stable-cascade", torch_dtype=dtype
-    )  # .to(device)
-    prior_pipeline.to(device)
-    decoder_pipeline.to(device)
-
-    if USE_TORCH_COMPILE:
-        prior_pipeline.prior = torch.compile(
-            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
-        )
-        decoder_pipeline.decoder = torch.compile(
-            decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
-        )
+    pipe = AutoPipelineForText2Image.from_pretrained(
+        model_id, torch_dtype=torch.float16, variant="fp16"
+    )
+    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+    pipe.to("cuda")
+    pipe.load_lora_weights(adapter_id)
+    pipe.fuse_lora()
 
 
 def generate(
     prompt: str,
     negative_prompt: str = "",
     seed: int = 0,
-    width: int = 1024,
-    height: int = 1024,
-    prior_num_inference_steps: int = 20,
-    prior_guidance_scale: float = 4.0,
-    decoder_num_inference_steps: int = 10,
-    decoder_guidance_scale: float = 0.0,
-    num_images_per_prompt: int = 2,
 ) -> PIL.Image.Image:
 
     generator = torch.Generator().manual_seed(seed)
-    prior_output = prior_pipeline(
-        prompt=prompt,
-        height=height,
-        width=width,
-        num_inference_steps=prior_num_inference_steps,
-        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
-        negative_prompt=negative_prompt,
-        guidance_scale=prior_guidance_scale,
-        num_images_per_prompt=num_images_per_prompt,
-        generator=generator,
-    )
-    decoder_output = decoder_pipeline(
-        image_embeddings=prior_output.image_embeddings,
+
+    image = pipe(
         prompt=prompt,
-        num_inference_steps=decoder_num_inference_steps,
-        # timesteps=decoder_timesteps,
-        guidance_scale=decoder_guidance_scale,
         negative_prompt=negative_prompt,
         generator=generator,
-        output_type="pil",
-    ).images
+        num_inference_steps=4,
+        guidance_scale=0,
+    ).images[0]
 
-    return decoder_output[0]
+    return image
 
 
 app = FastAPI()
@@ -120,7 +89,9 @@ app.add_middleware(
 
 
 @app.get("/image")
-async def generate_image(prompt: str, negative_prompt: str = "", seed: int = 2134213213):
+async def generate_image(
+    prompt: str, negative_prompt: str = "", seed: int = 2134213213
+):
    cached_img = database.check(prompt, negative_prompt, seed)
    if cached_img:
        logging.info(f"Image found in cache: {cached_img[0]}")