Rooni's picture
Update app.py
ea1eaed verified
import os
import random
import sys
import requests
from typing import Sequence, Mapping, Any, Union
import gradio as gr
from deep_translator import GoogleTranslator
from langdetect import detect
from gradio_client import Client, handle_file
from PIL import Image
# Функция для получения случайного API ключа
def get_random_api_key():
keys = os.getenv("KEYS", "").split(",")
if keys and keys[0]: # Проверяем, установлены ли ключи и не пусты ли они
return random.choice(keys).strip()
else:
raise ValueError("API ключи не найдены. Пожалуйста, установите переменную окружения KEYS.")
# Ссылка на файл CSS
css_url = "https://neurixyufi-aihub.static.hf.space/style.css"
# Получение CSS по ссылке
try:
response = requests.get(css_url)
response.raise_for_status()
css = response.text + " h1{text-align:center}"
except requests.exceptions.RequestException as e:
print(f"Ошибка при загрузке CSS: {e}")
css = " h1{text-align:center}"
# Функция для перевода текста на английский
def translate_to_english(prompt):
language = detect(prompt)
if language != 'en':
prompt = GoogleTranslator(source=language, target='en').translate(prompt)
return prompt
# Функция для загрузки изображений в кеш и отправки ссылки на API
def upload_image_to_hf_cache(image):
if isinstance(image, dict) and 'url' in image:
return image['url']
elif isinstance(image, str):
return image
else:
raise ValueError("Неподдерживаемый формат изображения")
# Функция для генерации изображения через API
def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
"""Основная функция генерации изображения."""
prompt = translate_to_english(prompt) if prompt else ""
structure_image_url = upload_image_to_hf_cache(structure_image)
style_image_url = upload_image_to_hf_cache(style_image)
client = Client("multimodalart/flux-style-shaping", hf_token=get_random_api_key())
result = client.predict(
prompt=prompt,
structure_image=handle_file(structure_image_url),
style_image=handle_file(style_image_url),
depth_strength=depth_strength,
style_strength=style_strength,
api_name="/generate_image"
)
if isinstance(result, str) and os.path.exists(result):
output_image = Image.open(result)
elif isinstance(result, bytes):
output_image = Image.open(BytesIO(result))
else:
raise ValueError(f"Неожиданный тип результата API: {type(result)}")
return output_image
# Примеры для Gradio
examples = [
["", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/mona.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/receita-tacos.webp", 15, 0.6],
["Девочка смотрит на дом, который горит", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/disaster_girl.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/abaporu.jpg", 15, 0.15],
["Город Истанбул с высоты птичьего полёта", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/natasha.png", "https://huggingface.co/spaces/multimodalart/flux-style-shaping/resolve/main/istambul.jpg", 15, 0.5],
]
output_image = gr.Image(label="Сгенерированное изображение", show_share_button=False)
with gr.Blocks(css=css) as app:
gr.Markdown("# Структуратор")
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(label="Запрос", placeholder="Введите ваш запрос здесь...")
with gr.Row():
with gr.Group():
structure_image = gr.Image(label="Изображение структуры", type="filepath")
depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Сила глубины")
with gr.Group():
style_image = gr.Image(label="Изображение стиля", type="filepath")
style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Сила стиля")
generate_btn = gr.Button("Создать", variant='primary')
gr.Examples(
examples=examples,
inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
outputs=[output_image],
fn=generate_image,
label="Примеры",
cache_examples=False,
)
with gr.Column():
output_image.render()
generate_btn.click(
fn=generate_image,
inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
outputs=[output_image],
concurrency_limit=250
)
if __name__ == "__main__":
app.launch(show_api=False, share=False)