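"""FastAPI app serving text-to-image generations from segmind/Segmind-Vega,
accelerated with the segmind/Segmind-VegaRT LCM-LoRA adapter. Generated
images are cached on disk and indexed via db.Database, so repeated
prompt/seed combinations are served without re-running the pipeline."""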
import io
import logging
import os
import uuid
from asyncio import Lock
from pathlib import Path

import numpy as np
import PIL.Image
import torch
import uvicorn
from diffusers import AutoPipelineForText2Image, LCMScheduler
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse

from db import Database
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

SPACE_ID = os.environ.get("SPACE_ID", "")
DEV = os.environ.get("DEV", "0") == "1"

# On a Space, keep the cache on the persistent /data disk; locally, use ./cache.
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
IMGS_PATH = DB_PATH / "imgs"
DB_PATH.mkdir(exist_ok=True, parents=True)
IMGS_PATH.mkdir(exist_ok=True, parents=True)

database = Database(DB_PATH)
generate_lock = Lock()  # serializes access to the GPU pipeline
model_id = "segmind/Segmind-Vega"
adapter_id = "segmind/Segmind-VegaRT"
dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id, torch_dtype=torch.float16, variant="fp16"
    )
    # VegaRT is an LCM LoRA: swap in the LCM scheduler, then load and fuse the adapter.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    pipe.load_lora_weights(adapter_id)
    pipe.fuse_lora()
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
) -> PIL.Image.Image:
    generator = torch.Generator().manual_seed(seed)
    # LCM inference: a few steps with classifier-free guidance disabled (guidance_scale=0).
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=4,
        guidance_scale=0,
    ).images[0]
    return image
app = FastAPI()

origins = [
    "https://huggingface.co",
    "http://huggingface.co",
    "https://huggingface.co/",
    "http://huggingface.co/",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def validate_origin(request: Request, call_next):
    if DEV:
        return await call_next(request)
    # With the default referrer policy, cross-site requests carry only the
    # origin (e.g. "https://huggingface.co/") in the Referer header.
    if request.headers.get("referer") not in origins:
        raise HTTPException(status_code=403, detail="Forbidden")
    return await call_next(request)
@app.get("/image")  # route decorator restored; the path is an assumption, missing from the extracted source
async def generate_image(
    prompt: str, negative_prompt: str = "", seed: int = 2134213213
):
    # Serve from the on-disk cache when the same prompt/seed was generated before.
    cached_img = database.check(prompt, negative_prompt, seed)
    if cached_img:
        logging.info(f"Image found in cache: {cached_img[0]}")
        return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
    logging.info("Image not found in cache, generating new image")
    async with generate_lock:
        pil_image = generate(prompt, negative_prompt, seed)
    # Persist the image for future cache hits, then stream it back from memory.
    img_id = str(uuid.uuid4())
    img_path = IMGS_PATH / f"{img_id}.jpg"
    pil_image.save(img_path)
    img_io = io.BytesIO()
    pil_image.save(img_io, "JPEG")
    img_io.seek(0)
    database.insert(prompt, negative_prompt, str(img_path), seed)
    return StreamingResponse(img_io, media_type="image/jpeg")
@app.get("/")  # route decorator restored; the root path is implied by the redirect below
async def main():
    # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
    return RedirectResponse(
        "https://multimodalart-stable-cascade.hf.space/?__theme=system"
    )
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)