ahmedmbutt commited on
Commit
5cf7dee
1 Parent(s): 089f816

Update main.py: Refactor code for CPU compatibility

Browse files
Files changed (1) hide show
  1. main.py +7 -4
main.py CHANGED
@@ -5,12 +5,12 @@ from starlette.middleware.cors import CORSMiddleware
5
 
6
  from PIL import Image
7
  from io import BytesIO
8
- from transformers import CLIPFeatureExtractor
9
  from diffusers import (
10
  AutoPipelineForText2Image,
11
  AutoPipelineForImage2Image,
12
  AutoPipelineForInpainting,
13
  )
 
14
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
15
 
16
 
@@ -18,11 +18,11 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
18
  async def lifespan(app: FastAPI):
19
  feature_extractor = CLIPFeatureExtractor.from_pretrained(
20
  "openai/clip-vit-base-patch32"
21
- )
22
 
23
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
24
  "CompVis/stable-diffusion-safety-checker"
25
- )
26
 
27
  text2img = AutoPipelineForText2Image.from_pretrained(
28
  "stabilityai/sd-turbo",
@@ -39,6 +39,7 @@ async def lifespan(app: FastAPI):
39
  del inpaint
40
  del img2img
41
  del text2img
 
42
  del safety_checker
43
  del feature_extractor
44
 
@@ -68,7 +69,9 @@ async def text_to_image(
68
  num_inference_steps: int = Form(1),
69
  ):
70
  results = request.state.text2img(
71
- prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0
 
 
72
  )
73
 
74
  if not results.nsfw_content_detected[0]:
 
5
 
6
  from PIL import Image
7
  from io import BytesIO
 
8
  from diffusers import (
9
  AutoPipelineForText2Image,
10
  AutoPipelineForImage2Image,
11
  AutoPipelineForInpainting,
12
  )
13
+ from transformers import CLIPFeatureExtractor
14
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
15
 
16
 
 
18
  async def lifespan(app: FastAPI):
19
  feature_extractor = CLIPFeatureExtractor.from_pretrained(
20
  "openai/clip-vit-base-patch32"
21
+ ).to("cpu")
22
 
23
  safety_checker = StableDiffusionSafetyChecker.from_pretrained(
24
  "CompVis/stable-diffusion-safety-checker"
25
+ ).to("cpu")
26
 
27
  text2img = AutoPipelineForText2Image.from_pretrained(
28
  "stabilityai/sd-turbo",
 
39
  del inpaint
40
  del img2img
41
  del text2img
42
+
43
  del safety_checker
44
  del feature_extractor
45
 
 
69
  num_inference_steps: int = Form(1),
70
  ):
71
  results = request.state.text2img(
72
+ prompt=prompt,
73
+ num_inference_steps=num_inference_steps,
74
+ guidance_scale=0.0,
75
  )
76
 
77
  if not results.nsfw_content_detected[0]: