|
import os |
|
from io import BytesIO |
|
|
|
import gradio as gr |
|
import grpc |
|
from PIL import Image |
|
from cachetools import LRUCache |
|
from gradio.image_utils import crop_scale |
|
|
|
from inference_pb2 import GuideAndRescaleRequest, GuideAndRescaleResponse |
|
from inference_pb2_grpc import GuideAndRescaleServiceStub |
|
|
|
|
|
def get_bytes(img):
    """Serialize an image to JPEG-encoded bytes.

    Args:
        img: Image object exposing ``mode``, ``convert`` and ``save``
            (e.g. a ``PIL.Image.Image``), or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # JPEG cannot store alpha channels or palettes; Pillow raises OSError when
    # asked to write such modes (RGBA, LA, P, ...). Modes the JPEG writer
    # already accepts are left untouched so existing behaviour is preserved.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")

    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()
|
|
|
|
|
def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL image.

    Args:
        image: Raw bytes of an encoded image file.

    Returns:
        The object returned by ``Image.open`` for an in-memory stream
        wrapping ``image``.
    """
    stream = BytesIO(image)
    return Image.open(stream)
|
|
|
|
|
def edit(editor, source_prompt, target_prompt, config, progress=gr.Progress(track_tqdm=True)):
    """Send the uploaded image and prompts to the inference server and return the result.

    Args:
        editor: Value of the image editor component: a mapping whose
            ``'composite'`` entry holds the PIL image. ``None`` or a mapping
            without an image is treated as "no image uploaded".
        source_prompt: Prompt passed as the request's ``source_prompt``.
        target_prompt: Prompt passed as the request's ``target_prompt``.
        config: Passed unchanged as the request's ``config`` field.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        The PIL image decoded from the ``image`` field of the server response.

    Raises:
        ValueError: If the image or either prompt is missing.
        KeyError: If the ``SERVER`` environment variable is not set.
    """
    # The editor value can be None (nothing uploaded yet); indexing it directly
    # would raise TypeError/KeyError instead of the user-facing message below.
    image = editor.get('composite') if editor else None

    if not image or not source_prompt or not target_prompt:
        raise ValueError("Need to upload an image and enter init and edit prompts")

    # Make the image square using the shorter side.
    # NOTE(review): crop_scale comes from gradio.image_utils; its exact crop
    # anchor is not visible here -- confirm it matches expectations.
    width, height = image.size
    if width != height:
        side = min(width, height)
        image = crop_scale(image, side, side)

    # Requests are always sent at 512x512 (presumably the resolution the
    # server expects -- confirm against the service).
    if image.size != (512, 512):
        image = image.resize((512, 512), Image.Resampling.LANCZOS)

    image_bytes = get_bytes(image)

    # The server address is read from the SERVER environment variable.
    with grpc.insecure_channel(os.environ['SERVER']) as channel:
        stub = GuideAndRescaleServiceStub(channel)
        response: GuideAndRescaleResponse = stub.swap(
            GuideAndRescaleRequest(
                image=image_bytes,
                source_prompt=source_prompt,
                target_prompt=target_prompt,
                config=config,
                use_cache=True,
            )
        )

    return bytes_to_image(response.image)
|
|
|
|
|
def get_demo():
    """Build the Gradio Blocks UI for the Guide-and-Rescale demo.

    The layout has an input column (image editor, the two prompts, the
    editing-type selector and the "Edit image" button) next to an output
    column with the resulting image, followed by example inputs and a
    citation block. Clicking the button calls ``edit``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Guide-and-Rescale")
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Official Guide-and-Rescale Gradio demo:</span>'
            '<a href="https://arxiv.org/abs/2409.01322"><img src="https://img.shields.io/badge/arXiv-2409.01322-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/FusionBrainLab/Guide-and-Rescale"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://colab.research.google.com/drive/1noKOOcDBBL_m5_UqU15jBBqiM8piLZ1O?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                with gr.Row():
                    image = gr.ImageEditor(label="Image that you want to edit", type="pil", layers=False,
                                           interactive=True, crop_size="1:1", eraser=False, brush=False,
                                           image_mode='RGB')
                with gr.Row():
                    source_prompt = gr.Textbox(label="Init Prompt", info="Describes the content on the original image.")
                    target_prompt = gr.Textbox(label="Edit Prompt",
                                               info="Describes what is expected in the output image.")
                    config = gr.Radio(["non-stylisation", "stylisation"], value='non-stylisation',
                                      label="Type of Editing", info="Selects a config for editing.")
                with gr.Row():
                    btn = gr.Button("Edit image")
            # Right column: result.
            with gr.Column():
                with gr.Row():
                    output = gr.Image(label="Result: edited image")

        gr.Examples(examples=[["input/1.png", 'A photo of a tiger', 'A photo of a lion', 'non-stylisation'],
                              ["input/zebra.jpeg", 'A photo of a zebra', 'A photo of a white horse', 'non-stylisation'],
                              ["input/13.png", 'A photo', 'Anime style face', 'stylisation']],
                    inputs=[image, source_prompt, target_prompt, config],
                    outputs=output)

        # NOTE(review): no handler function is supplied here -- confirm this
        # listener has any effect in the Gradio version in use.
        image.upload(inputs=[image], outputs=image)

        btn.click(fn=edit, inputs=[image, source_prompt, target_prompt, config], outputs=output)

        # A comma is required after the BibTeX citation key; without it the
        # entry users copy from the page does not parse.
        gr.Markdown(
            "To cite the paper by the authors\n"
            "```\n"
            "@article{titov2024guideandrescale,\n"
            "title={Guide-and-Rescale: Self-Guidance Mechanism for Effective Tuning-Free Real Image Editing},\n"
            "author={Vadim Titov and Madina Khalmatova and Alexandra Ivanova and Dmitry Vetrov and Aibek Alanov},\n"
            "journal={arXiv preprint arXiv:2409.01322},\n"
            "year={2024}\n"
            "}\n"
            "```\n"
        )
    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Small LRU cache created at start-up; nothing in this entry point reads it.
    align_cache = LRUCache(maxsize=10)

    # Build the UI, then serve it on 0.0.0.0 (all IPv4 interfaces), port 7860.
    demo = get_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|