Spaces:
Sleeping
Sleeping
ahmedmbutt
commited on
Commit
•
5cf7dee
1
Parent(s):
089f816
Update main.py: Refactor code for CPU compatibility
Browse files
main.py
CHANGED
@@ -5,12 +5,12 @@ from starlette.middleware.cors import CORSMiddleware
|
|
5 |
|
6 |
from PIL import Image
|
7 |
from io import BytesIO
|
8 |
-
from transformers import CLIPFeatureExtractor
|
9 |
from diffusers import (
|
10 |
AutoPipelineForText2Image,
|
11 |
AutoPipelineForImage2Image,
|
12 |
AutoPipelineForInpainting,
|
13 |
)
|
|
|
14 |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
15 |
|
16 |
|
@@ -18,11 +18,11 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
|
18 |
async def lifespan(app: FastAPI):
|
19 |
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
20 |
"openai/clip-vit-base-patch32"
|
21 |
-
)
|
22 |
|
23 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
24 |
"CompVis/stable-diffusion-safety-checker"
|
25 |
-
)
|
26 |
|
27 |
text2img = AutoPipelineForText2Image.from_pretrained(
|
28 |
"stabilityai/sd-turbo",
|
@@ -39,6 +39,7 @@ async def lifespan(app: FastAPI):
|
|
39 |
del inpaint
|
40 |
del img2img
|
41 |
del text2img
|
|
|
42 |
del safety_checker
|
43 |
del feature_extractor
|
44 |
|
@@ -68,7 +69,9 @@ async def text_to_image(
|
|
68 |
num_inference_steps: int = Form(1),
|
69 |
):
|
70 |
results = request.state.text2img(
|
71 |
-
prompt=prompt,
|
|
|
|
|
72 |
)
|
73 |
|
74 |
if not results.nsfw_content_detected[0]:
|
|
|
5 |
|
6 |
from PIL import Image
|
7 |
from io import BytesIO
|
|
|
8 |
from diffusers import (
|
9 |
AutoPipelineForText2Image,
|
10 |
AutoPipelineForImage2Image,
|
11 |
AutoPipelineForInpainting,
|
12 |
)
|
13 |
+
from transformers import CLIPFeatureExtractor
|
14 |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
15 |
|
16 |
|
|
|
18 |
async def lifespan(app: FastAPI):
|
19 |
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
20 |
"openai/clip-vit-base-patch32"
|
21 |
+
).to("cpu")
|
22 |
|
23 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
24 |
"CompVis/stable-diffusion-safety-checker"
|
25 |
+
).to("cpu")
|
26 |
|
27 |
text2img = AutoPipelineForText2Image.from_pretrained(
|
28 |
"stabilityai/sd-turbo",
|
|
|
39 |
del inpaint
|
40 |
del img2img
|
41 |
del text2img
|
42 |
+
|
43 |
del safety_checker
|
44 |
del feature_extractor
|
45 |
|
|
|
69 |
num_inference_steps: int = Form(1),
|
70 |
):
|
71 |
results = request.state.text2img(
|
72 |
+
prompt=prompt,
|
73 |
+
num_inference_steps=num_inference_steps,
|
74 |
+
guidance_scale=0.0,
|
75 |
)
|
76 |
|
77 |
if not results.nsfw_content_detected[0]:
|