# OmniGen / app.py
# Shitao's picture
# Update app.py
# c38e273 verified
# raw
# history blame
# 5.77 kB
import gradio as gr
from PIL import Image
import os
import spaces
from OmniGen import OmniGenPipeline
# Build the OmniGen pipeline once at import time; generate_image reuses this
# module-level instance for every request.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
# NOTE: the keyword must be spelled ``duration``; the previous ``duratio``
# typo meant the intended 120-second GPU allocation was never applied.
@spaces.GPU(duration=120)
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Generate one image with the module-level OmniGen pipeline.

    Args:
        text: Prompt text. Reference images are addressed in the prompt via
            ``<img><|image_i|></img>`` placeholders.
        img1: First optional reference image, or None when the slot is empty.
        img2: Second optional reference image, or None.
        img3: Third optional reference image, or None.
        height: Forwarded to the pipeline as ``height``.
        width: Forwarded to the pipeline as ``width``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        inference_steps: Forwarded to the pipeline as ``num_inference_steps``.
        seed: Forwarded to the pipeline as ``seed``.

    Returns:
        The first element of the pipeline output.
    """
    # Drop empty slots; when no image was supplied at all the pipeline is
    # given None instead of an empty list.
    input_images = [img for img in (img1, img2, img3) if img is not None] or None

    output = pipe(
        prompt=text,
        input_images=input_images,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        img_guidance_scale=1.6,
        num_inference_steps=inference_steps,
        separate_cfg_infer=True,
        use_kv_cache=False,
        seed=seed,
    )
    return output[0]
# def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
# input_images = []
# if img1:
# input_images.append(Image.open(img1))
# if img2:
# input_images.append(Image.open(img2))
# if img3:
# input_images.append(Image.open(img3))
# return input_images[0] if input_images else None
def get_example():
    """Return the example rows offered in the demo.

    Each row is a list ordered as: prompt, three image paths (None for an
    unused slot), height, width, guidance scale, inference steps, seed.
    """
    # Every example uses the same size and sampling settings.
    shared_settings = [1024, 1024, 3.0, 50, 42]

    text_only_prompt = "A woman holds a bouquet of flowers and faces the camera."
    single_image_prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is the one in <img><|image_1|></img>."
    three_image_prompt = "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>."

    prompt_and_images = [
        (text_only_prompt, [None, None, None]),
        (single_image_prompt, ["./imgs/test_cases/liuyifei.png", None, None]),
        (
            three_image_prompt,
            [
                "./imgs/test_cases/img1.jpg",
                "./imgs/test_cases/img2.jpg",
                "./imgs/test_cases/img3.jpg",
            ],
        ),
    ]
    return [[prompt, *images, *shared_settings] for prompt, images in prompt_and_images]
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Run an example row by forwarding all arguments, unchanged and in order, to generate_image."""
    example_args = (text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
    return generate_image(*example_args)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        "# OmniGen: Unified Image Generation [paper](https://arxiv.org/abs/2409.11340) [code](https://github.com/VectorSpaceLab/OmniGen)"
    )
    with gr.Row():
        with gr.Column():
            # Prompt text box
            prompt_input = gr.Textbox(
                label="Enter your prompt, use <img><|image_i|></img> tokens for images",
                placeholder="Type your prompt here...",
            )

            # Image upload slots; each is handed to the callback as a file path
            with gr.Row(equal_height=True):
                image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
                image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
                image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")

            # Output height and width sliders
            height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16)
            width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16)

            # Guidance scale, step count and seed
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=100, value=50, step=1
            )
            seed_input = gr.Slider(label="Seed", minimum=0, maximum=2147483647, value=42, step=1)

            # Generate button
            generate_button = gr.Button("Generate Image")

        with gr.Column():
            # Output image panel
            output_image = gr.Image(label="Output Image")

    # The button and the examples table feed the same components, in the
    # positional order expected by generate_image / run_for_examples.
    generation_inputs = [
        prompt_input,
        image_input_1,
        image_input_2,
        image_input_3,
        height_input,
        width_input,
        guidance_scale_input,
        num_inference_steps,
        seed_input,
    ]

    # Button click event
    generate_button.click(generate_image, inputs=generation_inputs, outputs=output_image)

    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=generation_inputs,
        outputs=output_image,
    )

# Launch the app
demo.launch()