# fast-stable-diffusion / inference.py
# Last change: "Update inference.py" by zenafey (commit 6e4bf3d)
import os
from concurrent.futures import ThreadPoolExecutor
from threading import Thread

import gradio as gr
import gradio_user_history as gr_user_history
from PIL import Image
from prodiapy import Custom
from prodiapy.util import load

from utils import image_to_base64
# Module-wide Prodia client shared by the request handlers in this file.
# The API key is read from the PRODIA_API_KEY environment variable;
# os.environ.get yields None when the variable is unset (same as os.getenv).
pipe = Custom(os.environ.get("PRODIA_API_KEY"))
def txt2img(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed, batch_count,
            profile: gr.OAuthProfile | None):
    """Generate ``batch_count`` images from a text prompt via Prodia's ``/sd/generate`` endpoint.

    Every request in the batch is sent with the same parameters (including ``seed``),
    and the requests run concurrently, one worker thread each. Every resulting image URL
    is loaded and passed to ``gr_user_history.save_image`` together with ``profile``.

    Args:
        prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed:
            Forwarded unchanged to the ``/sd/generate`` request.
        batch_count: Number of requests to send; converted with ``int()`` so a float
            slider value such as ``2.0`` is accepted. Values <= 0 produce no requests.
        profile: Gradio OAuth profile forwarded to the history store (may be None).

    Returns:
        ``gr.update(value=<list of image URLs>, preview=False)``. The URLs are in request
        order, not completion order.

    Raises:
        Any exception raised while creating or awaiting a request, or while loading or
        saving an image, propagates to the caller. It is no longer lost inside a worker
        thread.
    """
    def generate_one_image(_index):
        # Submit one generation job and block until Prodia reports it finished.
        result = pipe.create(
            "/sd/generate",
            prompt=prompt,
            negative_prompt=negative_prompt,
            model=model,
            steps=steps,
            cfg_scale=cfg_scale,
            sampler=sampler,
            width=width,
            height=height,
            seed=seed,
        )
        job = pipe.wait_for(result)
        return job['imageUrl']

    count = int(batch_count)
    total_images = []
    if count > 0:  # ThreadPoolExecutor rejects max_workers=0
        # One worker per request keeps the batch fully parallel. executor.map yields results
        # in submission order and re-raises any worker exception here in the caller.
        with ThreadPoolExecutor(max_workers=count) as executor:
            total_images = list(executor.map(generate_one_image, range(count)))

    for image in total_images:
        gr_user_history.save_image(label=prompt, image=Image.open(load(image)), profile=profile)
    return gr.update(value=total_images, preview=False)
def img2img(input_image, denoising, prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed,
            batch_count):
    """Transform ``input_image`` ``batch_count`` times via Prodia's ``/sd/transform`` endpoint.

    All requests share the same parameters (including ``seed``) and run concurrently,
    one worker thread each. Unlike ``txt2img``, results are not written to user history.

    Args:
        input_image: Source image, encoded with ``image_to_base64``. If None, nothing is
            requested.
        denoising: Sent as ``denoising_strength``.
        prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed:
            Forwarded unchanged to the ``/sd/transform`` request.
        batch_count: Number of requests to send; converted with ``int()`` so a float
            slider value is accepted. Values <= 0 produce no requests.

    Returns:
        None when ``input_image`` is None. Otherwise
        ``gr.update(value=<list of image URLs>, preview=False)``, with URLs in request order.

    Raises:
        Any exception raised while creating or awaiting a request propagates to the
        caller. It is no longer lost inside a worker thread.
    """
    if input_image is None:
        return None

    # The source image is identical for every request, so encode it once
    # instead of once per worker thread.
    image_data = image_to_base64(input_image)

    def generate_one_image(_index):
        # Submit one transform job and block until Prodia reports it finished.
        result = pipe.create(
            "/sd/transform",
            imageData=image_data,
            denoising_strength=denoising,
            prompt=prompt,
            negative_prompt=negative_prompt,
            model=model,
            steps=steps,
            cfg_scale=cfg_scale,
            sampler=sampler,
            width=width,
            height=height,
            seed=seed,
        )
        job = pipe.wait_for(result)
        return job['imageUrl']

    count = int(batch_count)
    total_images = []
    if count > 0:  # ThreadPoolExecutor rejects max_workers=0
        # executor.map keeps submission order and re-raises worker exceptions here.
        with ThreadPoolExecutor(max_workers=count) as executor:
            total_images = list(executor.map(generate_one_image, range(count)))
    return gr.update(value=total_images, preview=False)
def upscale(image, scale, profile: gr.OAuthProfile | None):
    """Upscale ``image`` by ``scale`` using Prodia's ``/upscale`` endpoint.

    Returns None when no image is supplied. Otherwise the upscaled image URL is loaded
    and passed to ``gr_user_history.save_image`` (labelled ``upscale by <scale>``)
    together with ``profile``, and the URL is returned.
    """
    if image is not None:
        encoded_image = image_to_base64(image)
        upscale_job = pipe.create('/upscale', imageData=encoded_image, resize=scale)
        upscaled_url = pipe.wait_for(upscale_job)['imageUrl']
        gr_user_history.save_image(
            label=f'upscale by {scale}',
            image=Image.open(load(upscaled_url)),
            profile=profile,
        )
        return upscaled_url
    return None