Spaces:
Running
Running
import io
import os
import random

import gradio as gr
import requests
import numpy as np
from PIL import Image
from deep_translator import GoogleTranslator
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
import cv2
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# Dump the installed package versions to stdout.
os.system("pip freeze")

# Weight files used by this app, as (local file name, release URL) pairs.
# Each one is fetched into the working directory with wget, but only when
# the file is not already present.
_WEIGHT_DOWNLOADS = (
    ('realesr-general-x4v3.pth',
     'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'),
    ('GFPGANv1.2.pth',
     'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth'),
    ('GFPGANv1.3.pth',
     'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
    ('GFPGANv1.4.pth',
     'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    ('RestoreFormer.pth',
     'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    ('CodeFormer.pth',
     'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'),
)

for _weight_file, _weight_url in _WEIGHT_DOWNLOADS:
    if os.path.exists(_weight_file):
        continue
    os.system(f"wget {_weight_url} -P .")
# Background enhancer: a Real-ESRGAN general-purpose x4 SRVGG network, wrapped
# in a RealESRGANer that is shared as `bg_upsampler` by the face enhancers.
model_us_path = 'realesr-general-x4v3.pth'
model_us = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# RealESRGANer's `half` option is enabled only when a CUDA device is present.
half = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path=model_us_path,
    model=model_us,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)

# Directory the upscaler writes its results into.
os.makedirs('output', exist_ok=True)
# Default Hugging Face Inference API endpoint (DALL-E 3 XL).
API_URL = "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl"
# Optional read token taken from the environment (a free HF token is enough).
API_TOKEN = os.getenv("HF_READ_TOKEN")
# Only send an Authorization header when a token is configured; otherwise the
# request would carry the literal credential "Bearer None".
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

# Model display names offered in the UI (the radio button choices).
models_list = [
    "AbsoluteReality 1.8.1",
    "DALL-E 3 XL",
    "Playground 2",
    "Openjourney 4",
    "Lyriel 1.6",
    "Animagine XL 2.0",
    "Counterfeit 2.5",
    "Realistic Vision 5.1",
    "Incursios 1.6",
    "Anime Detailer XL LoRA",
    "epiCRealism",
    "PixelArt XL",
    "NewReality XL",
]
_HF_MODELS_PREFIX = "https://api-inference.huggingface.co/models/"

# Inference API endpoint for each model display name shown in the UI.
_MODEL_ENDPOINTS = {
    'DALL-E 3 XL': _HF_MODELS_PREFIX + "openskyml/dalle-3-xl",
    'Playground 2': _HF_MODELS_PREFIX + "playgroundai/playground-v2-1024px-aesthetic",
    'Openjourney 4': _HF_MODELS_PREFIX + "prompthero/openjourney-v4",
    'AbsoluteReality 1.8.1': _HF_MODELS_PREFIX + "digiplay/AbsoluteReality_v1.8.1",
    'Lyriel 1.6': _HF_MODELS_PREFIX + "stablediffusionapi/lyrielv16",
    'Animagine XL 2.0': _HF_MODELS_PREFIX + "Linaqruf/animagine-xl-2.0",
    'Counterfeit 2.5': _HF_MODELS_PREFIX + "gsdf/Counterfeit-V2.5",
    'Realistic Vision 5.1': _HF_MODELS_PREFIX + "stablediffusionapi/realistic-vision-v51",
    'Incursios 1.6': _HF_MODELS_PREFIX + "digiplay/incursiosMemeDiffusion_v1.6",
    'Anime Detailer XL LoRA': _HF_MODELS_PREFIX + "Linaqruf/anime-detailer-xl-lora",
    'epiCRealism': _HF_MODELS_PREFIX + "emilianJR/epiCRealism",
    'PixelArt XL': _HF_MODELS_PREFIX + "nerijs/pixel-art-xl",
    'NewReality XL': _HF_MODELS_PREFIX + "stablediffusionapi/newrealityxl-global-nsfw",
}
# Endpoint used when `model` is not one of the names above.
_DEFAULT_MODEL_ENDPOINT = _MODEL_ENDPOINTS['DALL-E 3 XL']


def query(prompt, model, is_negative=False, steps=20, cfg_scale=7, seed=None):
    """Generate an image for `prompt` via the Hugging Face Inference API.

    Prompts that langdetect identifies as Russian are machine-translated to
    English before being sent.

    Args:
        prompt: text description of the desired image.
        model: display name (a key of `_MODEL_ENDPOINTS`). Unknown names fall
            back to the DALL-E 3 XL endpoint instead of crashing.
        is_negative: sent verbatim as the payload's "is_negative" field (the UI
            passes the negative-prompt text here).
        steps: sent verbatim as "steps".
        cfg_scale: sent verbatim as "cfg_scale".
        seed: sent as "seed"; drawn from random.randint(-1, 2147483647) when None.

    Returns:
        The response body decoded as a PIL image.

    Raises:
        requests.HTTPError: if the API answers with an error status.
    """
    try:
        language = detect(prompt)
    except LangDetectException:
        # langdetect raises on empty / feature-less text (e.g. "" or "123");
        # such prompts are simply sent untranslated.
        language = None
    if language == 'ru':
        prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
    print(f'\033[1mГенерация:\033[0m {prompt}')

    api_url = _MODEL_ENDPOINTS.get(model, _DEFAULT_MODEL_ENDPOINT)
    # NOTE(review): these payload keys are not the Inference API's documented
    # `parameters` schema; confirm the target endpoints actually honour them.
    payload = {
        "inputs": prompt,
        "is_negative": is_negative,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed is not None else random.randint(-1, 2147483647),
    }
    response = requests.post(api_url, headers=headers, json=payload)
    # On failure the API returns a JSON error body, not image bytes; surface
    # the HTTP error instead of an opaque "cannot identify image file".
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
# (weight file, GFPGANer `arch`) for each face-restoration version in the UI.
_FACE_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    # NOTE(review): 'CodeFormer' is not an arch in upstream gfpgan's GFPGANer;
    # confirm the installed gfpgan build supports it.
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
}


def _load_bgr_image(img):
    """Return `img` as an OpenCV-style ndarray (BGR, BGRA or single-channel).

    Accepts a file path (read with cv2.imread, unchanged depth/alpha), a PIL
    image (what `gr.Image(type="pil")` delivers), or an ndarray that is assumed
    to be RGB/RGBA like Gradio's numpy images.

    Raises:
        ValueError: if a path cannot be read as an image.
    """
    if isinstance(img, (str, os.PathLike)):
        arr = cv2.imread(os.fspath(img), cv2.IMREAD_UNCHANGED)
        if arr is None:
            raise ValueError(f'could not read image: {img!r}')
        return arr
    if isinstance(img, Image.Image) and img.mode not in ('RGB', 'RGBA', 'L'):
        # Normalise palette / exotic modes, keeping transparency when present.
        has_alpha = 'A' in img.getbands() or 'transparency' in img.info
        img = img.convert('RGBA' if has_alpha else 'RGB')
    arr = np.asarray(img)
    if arr.ndim == 3 and arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)
    if arr.ndim == 3 and arr.shape[2] == 3:
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    return arr


def up(img, version, scale, weight):
    """Restore faces in `img`, upscale the background, and resize by `scale`.

    Args:
        img: input image: a file path, a PIL image, or an RGB(A) ndarray.
        version: key of `_FACE_MODELS` ('v1.2', 'v1.3', 'v1.4',
            'RestoreFormer', 'CodeFormer').
        scale: final size factor relative to the input image.
        weight: 0-100 slider value; divided by 100 and passed to
            `enhance(weight=...)` (the UI labels it as CodeFormer-only).

    Returns:
        The result as an RGB (or RGBA) ndarray, also written to
        output/out.png (inputs with alpha) or output/out.jpg; None on any
        failure (the error is printed).
    """
    weight /= 100
    print(img, version, scale, weight)
    try:
        img = _load_bgr_image(img)
        if img.ndim == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif img.ndim == 2:  # grayscale input
            img_mode = None
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_mode = None

        if version not in _FACE_MODELS:
            print('Error', f'unknown version: {version!r}')
            return None
        weights_file, arch = _FACE_MODELS[version]
        # GFPGANer's weight-path parameter is `model_path`.
        face_enhancer = GFPGANer(
            model_path=weights_file, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)

        try:
            _, _, output = face_enhancer.enhance(
                img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
        except RuntimeError as error:
            # Without a result there is nothing to resize or save.
            print('Error', error)
            return None

        try:
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            h, w = img.shape[0:2]
            output = cv2.resize(output, (int(w * scale), int(h * scale)), interpolation=interpolation)
        except Exception as error:
            # Best effort: keep the un-resized result on a bad scale value.
            print('wrong scale input.', error)

        # RGBA images must be saved as PNG to keep their alpha channel.
        extension = 'png' if img_mode == 'RGBA' else 'jpg'
        cv2.imwrite(f'output/out.{extension}', output)
        if output.ndim == 3 and output.shape[2] == 4:
            return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    except Exception as error:
        print('global exception', error)
        return None
# Hide Gradio's default page footer.
css = """
footer {visibility: hidden !important;}
"""

with gr.Blocks(css=css) as dalle:
    with gr.Tab("Базовые настройки"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                text_prompt = gr.Textbox(label="Prompt", placeholder="Описание изображения", lines=3, elem_id="prompt-text-input")
                model = gr.Radio(label="Модель", value="DALL-E 3 XL", choices=models_list)
    with gr.Tab("Расширенные настройки"):
        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Чего не должно быть на изображении", value="[deformed | disfigured], poorly drawn, [bad : wrong] anatomy, [extra | missing | floating | disconnected] limb, (mutated hands and fingers), blurry, text, fuzziness", lines=3, elem_id="negative-prompt-text-input")
    with gr.Tab("Настройки апскейлинга"):
        # No trailing commas here: `x = gr.Radio(...),` binds a 1-tuple, not a
        # component, and breaks the event wiring below.
        up_1 = gr.Radio(choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], value='v1.4', label='Версия')
        up_2 = gr.Slider(label="Коэффициент масштабирования", value=2, minimum=2, maximum=6)
        up_3 = gr.Slider(0, 100, label='Weight, только для CodeFormer. 0 для лучшего качества, 100 для лучшей идентичности', value=50)
    with gr.Row():
        text_button = gr.Button("Генерация", variant='primary', elem_id="gen-button")
    with gr.Row():
        # `up` opens file paths with cv2.imread, so hand it the generated image
        # as a path rather than a PIL object (query's PIL output still displays).
        image_output = gr.Image(type="filepath", label="Изображение", elem_id="gallery")
    with gr.Row():
        # Unique elem_ids: "gen-button"/"gallery" are already used above.
        up_button = gr.Button("Улучшить изображение", variant='primary', elem_id="up-button")
    with gr.Row():
        up_output = gr.Image(type="pil", label="Улучшенное изображение", elem_id="up-gallery")
    text_button.click(query, inputs=[text_prompt, model, negative_prompt], outputs=image_output)
    up_button.click(up, inputs=[image_output, up_1, up_2, up_3], outputs=up_output)

dalle.launch(show_api=False)