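# NOTE: the next six lines appear to be residue from a web file viewer
# (author avatar caption, commit id, size) rather than part of the program.
# radames's picture
# no
# 22e6fb5
# raw
# history blame
# 7.52 kB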
import os
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
from typing import List
from diffusers.utils import numpy_to_pil
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: str = ""
seed: int = 0
app = FastAPI()
origins = [
"http://localhost.tiangolo.com",
"https://localhost.tiangolo.com",
"http://localhost",
"http://localhost:8080",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def main():
# redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
return RedirectResponse("https://multimodalart-stable-cascade.hf.space/?__theme=system")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
# MAX_SEED = np.iinfo(np.int32).max
# USE_TORCH_COMPILE = False
# dtype = torch.bfloat16
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# if torch.cuda.is_available():
# prior_pipeline = StableCascadePriorPipeline.from_pretrained(
# "stabilityai/stable-cascade-prior", torch_dtype=dtype) # .to(device)
# decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
# "stabilityai/stable-cascade", torch_dtype=dtype) # .to(device)
# prior_pipeline.to(device)
# decoder_pipeline.to(device)
# if USE_TORCH_COMPILE:
# prior_pipeline.prior = torch.compile(
# prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
# decoder_pipeline.decoder = torch.compile(
# decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
# else:
# prior_pipeline = None
# decoder_pipeline = None
# def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
# if randomize_seed:
# seed = random.randint(0, MAX_SEED)
# return seed
# def generate(
# prompt: str,
# negative_prompt: str = "",
# seed: int = 0,
# width: int = 1024,
# height: int = 1024,
# prior_num_inference_steps: int = 30,
# # prior_timesteps: List[float] = None,
# prior_guidance_scale: float = 4.0,
# decoder_num_inference_steps: int = 12,
# # decoder_timesteps: List[float] = None,
# decoder_guidance_scale: float = 0.0,
# num_images_per_prompt: int = 2,
# progress=gr.Progress(track_tqdm=True),
# ) -> PIL.Image.Image:
# generator = torch.Generator().manual_seed(seed)
# prior_output = prior_pipeline(
# prompt=prompt,
# height=height,
# width=width,
# num_inference_steps=prior_num_inference_steps,
# timesteps=DEFAULT_STAGE_C_TIMESTEPS,
# negative_prompt=negative_prompt,
# guidance_scale=prior_guidance_scale,
# num_images_per_prompt=num_images_per_prompt,
# generator=generator,
# )
# decoder_output = decoder_pipeline(
# image_embeddings=prior_output.image_embeddings,
# prompt=prompt,
# num_inference_steps=decoder_num_inference_steps,
# # timesteps=decoder_timesteps,
# guidance_scale=decoder_guidance_scale,
# negative_prompt=negative_prompt,
# generator=generator,
# output_type="pil",
# ).images
# return decoder_output[0]
# examples = [
# "An astronaut riding a green horse",
# "A mecha robot in a favela by Tarsila do Amaral",
# "The spirit of a Tamagotchi wandering in the city of Los Angeles",
# "A delicious feijoada ramen dish"
# ]
# with gr.Blocks() as demo:
# gr.Markdown(DESCRIPTION)
# gr.DuplicateButton(
# value="Duplicate Space for private use",
# elem_id="duplicate-button",
# visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
# )
# with gr.Group():
# with gr.Row():
# prompt = gr.Text(
# label="Prompt",
# show_label=False,
# max_lines=1,
# placeholder="Enter your prompt",
# container=False,
# )
# run_button = gr.Button("Run", scale=0)
# result = gr.Image(label="Result", show_label=False)
# with gr.Accordion("Advanced options", open=False):
# negative_prompt = gr.Text(
# label="Negative prompt",
# max_lines=1,
# placeholder="Enter a Negative Prompt",
# )
# seed = gr.Slider(
# label="Seed",
# minimum=0,
# maximum=MAX_SEED,
# step=1,
# value=0,
# )
# randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
# with gr.Row():
# width = gr.Slider(
# label="Width",
# minimum=1024,
# maximum=1536,
# step=512,
# value=1024,
# )
# height = gr.Slider(
# label="Height",
# minimum=1024,
# maximum=1536,
# step=512,
# value=1024,
# )
# num_images_per_prompt = gr.Slider(
# label="Number of Images",
# minimum=1,
# maximum=2,
# step=1,
# value=1,
# )
# with gr.Row():
# prior_guidance_scale = gr.Slider(
# label="Prior Guidance Scale",
# minimum=0,
# maximum=20,
# step=0.1,
# value=4.0,
# )
# prior_num_inference_steps = gr.Slider(
# label="Prior Inference Steps",
# minimum=10,
# maximum=30,
# step=1,
# value=20,
# )
# decoder_guidance_scale = gr.Slider(
# label="Decoder Guidance Scale",
# minimum=0,
# maximum=0,
# step=0.1,
# value=0.0,
# )
# decoder_num_inference_steps = gr.Slider(
# label="Decoder Inference Steps",
# minimum=4,
# maximum=12,
# step=1,
# value=10,
# )
# gr.Examples(
# examples=examples,
# inputs=prompt,
# outputs=result,
# fn=generate,
# cache_examples=False,
# )
# inputs = [
# prompt,
# negative_prompt,
# seed,
# width,
# height,
# prior_num_inference_steps,
# # prior_timesteps,
# prior_guidance_scale,
# decoder_num_inference_steps,
# # decoder_timesteps,
# decoder_guidance_scale,
# num_images_per_prompt,
# ]
# gr.on(
# triggers=[prompt.submit, negative_prompt.submit, run_button.click],
# fn=randomize_seed_fn,
# inputs=[seed, randomize_seed],
# outputs=seed,
# queue=False,
# api_name=False,
# ).then(
# fn=generate,
# inputs=inputs,
# outputs=result,
# api_name="run",
# )
# if __name__ == "__main__":
# demo.queue(max_size=20).launch()