File size: 5,975 Bytes
6af9ebe
cf44c0c
 
 
6af9ebe
 
cf44c0c
 
405f02e
f2f40d5
 
cf44c0c
 
 
550e3c0
aca76a9
 
550e3c0
6af9ebe
 
 
cf44c0c
 
 
 
 
 
 
 
6af9ebe
75ea3d4
6af9ebe
 
7b300c3
550e3c0
0402965
550e3c0
 
 
 
 
 
 
 
0402965
 
 
 
550e3c0
 
7b300c3
 
cf44c0c
 
 
 
 
6a72c67
 
 
 
 
550e3c0
cf44c0c
 
 
d8b1d45
 
 
 
1406a42
d8b1d45
cf44c0c
 
 
 
1e5c4a6
 
 
d8b1d45
 
 
1e5c4a6
 
b2e3d76
cf44c0c
 
 
 
 
d5c8046
7b300c3
d8b1d45
cf44c0c
 
 
 
f2f40d5
 
 
 
 
b2e3d76
f2f40d5
 
 
 
 
 
 
 
b2e3d76
f2f40d5
0402965
550e3c0
 
d8b1d45
 
 
 
cf44c0c
7b300c3
b2e3d76
7b300c3
cf44c0c
 
b2e3d76
cf44c0c
 
 
 
 
 
 
 
 
 
b8b68ce
 
 
 
b2e3d76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf44c0c
be07774
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from flask import Flask, request, send_file, abort
import requests
import io
from PIL import Image
from cachetools import TTLCache, cached
import random
import os
import urllib.parse
import hashlib
from deep_translator import GoogleTranslator
from langdetect import detect

app = Flask(__name__)

# Максимальные значения для ширины и высоты
MAX_WIDTH = 1384
MAX_HEIGHT = 1384

# Кэш на 10 минут
cache = TTLCache(maxsize=100, ttl=600)

# Получаем ключи из переменной окружения
keys = os.getenv("keys", "").split(',')
if not keys:
    raise ValueError("Environment variable 'keys' must be set with a comma-separated list of API keys.")

def get_random_key():
    return random.choice(keys)

def generate_cache_key(prompt, width, height, seed, model_name):
    # Создаем уникальный ключ на основе всех параметров
    return hashlib.md5(f"{prompt}_{width}_{height}_{seed}_{model_name}".encode()).hexdigest()


def scale_dimensions(width, height, max_width, max_height):
    """Масштабирует размеры изображения, сохраняя соотношение сторон, и округляет до чисел, кратных 8."""
    aspect_ratio = width / height
    if width > max_width or height > max_height:
        if width / max_width > height / max_height:
            width = max_width
            height = int(width / aspect_ratio)
        else:
            height = max_height
            width = int(height * aspect_ratio)
    
    # Округляем до ближайших чисел, кратных 8
    width = (width + 3) // 8 * 8
    height = (height + 3) // 8 * 8
    return width, height

@cached(cache, key=lambda prompt, width, height, seed, model_name: generate_cache_key(prompt, width, height, seed, model_name))
def generate_cached_image(prompt, width, height, seed, model_name, api_key):
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "inputs": prompt,
        "parameters": {
            "width": width,
            "height": height,
            "seed": seed
        }
    }

    try:
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{model_name}",
            headers=headers,
            json=data,
            timeout=1550  # Таймаут 3 минуты
        )
        response.raise_for_status()
        image_data = response.content
        image = Image.open(io.BytesIO(image_data))
        return image
    except requests.exceptions.HTTPError as http_err:
        app.logger.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
        return None
    except requests.exceptions.Timeout as timeout_err:
        app.logger.error(f"Timeout error occurred: {timeout_err}")
        return None
    except requests.exceptions.RequestException as req_err:
        app.logger.error(f"Request error occurred: {req_err}")
        return None

@app.route('/prompt/<path:prompt>')
def get_image(prompt):
    width = request.args.get('width', type=int, default=512)
    height = request.args.get('height', type=int, default=512)
    seed = request.args.get('seed', type=int, default=25)
    model_name = request.args.get('model', default="black-forest-labs/FLUX.1-schnell").replace('+', '/')
    api_key = request.args.get('key', default=None)

    # Декодируем URL-кодированный prompt
    prompt = urllib.parse.unquote(prompt)

    # Определяем язык промпта
    try:
        language = detect(prompt)
    except Exception as e:
        app.logger.error(f"Error detecting language: {e}")
        return send_error_image()

    # Переводим промпт, если он не на английском языке
    if language != 'en':
        try:
            translator = GoogleTranslator(source=language, target='en')
            prompt = translator.translate(prompt)
        except Exception as e:
            app.logger.error(f"Error translating prompt: {e}")
            return send_error_image()

    # Масштабируем размеры изображения, если они превышают максимальные значения, и округляем до чисел, кратных 8
    width, height = scale_dimensions(width, height, MAX_WIDTH, MAX_HEIGHT)

    # Используем указанный ключ, если он предоставлен, иначе выбираем случайный ключ
    if api_key is None:
        api_key = get_random_key()

    try:
        image = generate_cached_image(prompt, width, height, seed, model_name, api_key)
        if image is None:
            return send_error_image()
    except Exception as e:
        app.logger.error(f"Error generating image: {e}")
        return send_error_image()

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()

    return send_file(
        io.BytesIO(img_byte_arr),
        mimetype='image/png'
    )

@app.route('/')
def health_check():
    return "OK", 200

def send_error_image():
    error_image_url = "https://raw.githubusercontent.com/Igroshka/-/refs/heads/main/img/nuai/errorimg.png"
    try:
        response = requests.get(error_image_url)
        response.raise_for_status()
        error_image = Image.open(io.BytesIO(response.content))
        img_byte_arr = io.BytesIO()
        error_image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
        return send_file(
            io.BytesIO(img_byte_arr),
            mimetype='image/png'
        )
    except Exception as e:
        app.logger.error(f"Error fetching error image: {e}")
        abort(500, description="Error fetching error image")

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)