File size: 5,177 Bytes
435f67e
 
92b8b0d
435f67e
 
92b8b0d
 
2c4be8e
435f67e
92b8b0d
 
2c4be8e
 
 
 
 
 
 
 
435f67e
d029a83
286f519
 
61fd9f3
435f67e
92b8b0d
 
435f67e
 
 
 
92b8b0d
435f67e
 
 
 
 
 
 
 
 
 
 
 
 
 
bae6956
f913d22
 
90dfa53
435f67e
 
c23c889
 
90dfa53
435f67e
 
 
92b8b0d
435f67e
92b8b0d
 
 
 
 
 
 
 
435f67e
 
92b8b0d
 
435f67e
700fd72
 
92b8b0d
700fd72
 
74d0aae
700fd72
 
 
 
 
1fb1361
92b8b0d
700fd72
92b8b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
435f67e
 
 
92b8b0d
 
435f67e
 
92b8b0d
435f67e
 
 
 
 
 
 
 
 
c23c889
 
 
 
 
 
92b8b0d
2c4be8e
92b8b0d
2c4be8e
92b8b0d
 
700fd72
435f67e
 
 
 
 
 
92b8b0d
 
2c4be8e
92b8b0d
 
435f67e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import asyncio
import io
import os
import random

import aiohttp
from deep_translator import GoogleTranslator
from flask_caching import Cache  # Flask-Caching provides the response cache
from PIL import Image
from quart import Quart, request, jsonify, send_file, render_template_string

# Application object.
app = Quart(__name__)

# Flask-Caching configuration: a memory-based SimpleCache whose entries
# expire after one day (24 h = 86400 s) unless a per-entry timeout is given.
cache_config = dict(
    CACHE_TYPE="SimpleCache",
    CACHE_DEFAULT_TIMEOUT=24 * 60 * 60,
)
app.config.from_mapping(cache_config)
cache = Cache(app)

API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
# Hugging Face access token read from the environment. If HF_READ_TOKEN is
# unset this is None and the header below becomes "Bearer None".
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
# Request timeout in seconds used for the inference call: 50000 s (~13.9 hours).
# NOTE(review): an earlier comment here said 300 seconds; confirm which is intended.
timeout = 50000

# Asynchronous function that performs one image-generation API request.
async def query_async(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, num_inference_steps=30, guidance_scale=7.5, top_k=50, top_p=0.9, eta=0.1):
    """Translate *prompt* and request an image from the inference API at ``API_URL``.

    The prompt is translated from Russian to English with deep_translator's
    ``GoogleTranslator``, suffixed with fixed quality tags, and POSTed as JSON
    together with the generation parameters.

    Args:
        prompt: Text prompt; an empty/falsy value is rejected.
        negative_prompt, steps, cfg_scale, strength, top_k, top_p, eta:
            Sent as top-level fields of the JSON payload.
        sampler: Accepted for interface compatibility; it is NOT sent in the
            payload.
        seed: Seed sent to the API; ``-1`` means pick a random seed in
            ``[1, 1_000_000_000]``.
        width, height, num_inference_steps, guidance_scale: Sent under the
            payload's ``"parameters"`` object.

    Returns:
        ``(image, None)`` with a fully decoded ``PIL.Image.Image`` on success,
        or ``(None, message)`` when the prompt is empty, translation fails, the
        API returns a non-200 status, the request times out, or the response
        body cannot be decoded as an image.
    """
    if not prompt:
        return None, "Prompt is required"

    # Random tag used only to correlate the log lines of one generation.
    key = random.randint(0, 999)

    # Translate the prompt from Russian to English. translate() is a
    # synchronous (network-bound) call, so run it in the default executor
    # instead of stalling the event loop for every other request; a failure
    # is reported through the (None, error) contract instead of escaping.
    try:
        translator = GoogleTranslator(source='ru', target='en')
        translated = await asyncio.get_running_loop().run_in_executor(
            None, translator.translate, prompt
        )
    except Exception as e:
        return None, f"Translation Exception: {str(e)}"
    # Never send the literal text "None" upstream if nothing came back.
    prompt = translated or prompt
    print(f'Generation {key} translation: {prompt}')

    prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation {key}: {prompt}')

    payload = {
        "inputs": prompt,
        "is_negative": False,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed != -1 else random.randint(1, 1000000000),
        "strength": strength,
        "negative_prompt": negative_prompt,
        "top_k": top_k,
        "top_p": top_p,
        "eta": eta,
        "parameters": {
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale
        }
    }

    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(
                API_URL,
                json=payload,
                headers=headers,
                # ClientTimeout is aiohttp's documented form; a bare number is
                # only accepted for backward compatibility.
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                if response.status != 200:
                    return None, f"Error: Failed to get image. Status code: {response.status}, Details: {await response.text()}"
                image_bytes = await response.read()

            # Image.open() is lazy; load() forces the full decode here so a
            # corrupt or non-image body is reported as an error instead of
            # failing later when the caller saves the image.
            image = Image.open(io.BytesIO(image_bytes))
            image.load()
            return image, None
        except asyncio.TimeoutError:
            return None, "Error: The request timed out. Please try again."
        except Exception as e:
            return None, f"Request Exception: {str(e)}"


# Attach a Content-Security-Policy header to every response.
@app.after_request
async def add_security_headers(response):
    """Set the ``Content-Security-Policy`` header on *response* and return it.

    CSP source lists accept scheme/host sources, not regular expressions, and
    every directive must be terminated with ``;``. Otherwise the following
    directive (here ``img-src``) is parsed as part of the previous one and
    silently dropped.
    """
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        # Scheme sources covering any http(s) URL (the intent of the old regex).
        "connect-src 'self' http: https:; "
        "img-src 'self' data:; "
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response

# HTML template for the landing page: a single text input named "prompt",
# submitted via GET to /generate (no other generation parameters are exposed).
index_html = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Image Generator</title>
</head>
<body>
    <h1>Welcome to the Image Generator</h1>
    <form action="/generate" method="get">
        <input type="text" name="prompt" placeholder="Enter prompt" required>
        <input type="submit" value="Generate">
    </form>
</body>
</html>
"""

@app.route('/')
async def index():
    """Serve the landing page rendered from the inline ``index_html`` template."""
    page = await render_template_string(index_html)
    return page

def _query_arg(args, name, default, convert):
    """Return query parameter *name* from *args* passed through *convert*.

    When the parameter is absent, *default* is converted instead, matching the
    ``convert(args.get(name, default))`` pattern.

    Raises:
        ValueError: the value cannot be converted; the message names the
            offending parameter and value.
    """
    raw = args.get(name, default)
    try:
        return convert(raw)
    except ValueError:
        raise ValueError(f"Invalid value for '{name}': {raw!r}") from None


@app.route('/generate', methods=['GET'])
async def generate_image():
    """Return a PNG generated from the query-string parameters, using the cache.

    Parameters mirror ``query_async``'s arguments (the CFG scale is read from
    the ``cfgs`` key). Responds with the PNG on success, or a JSON
    ``{"error": ...}`` body with status 400 when a numeric parameter is
    malformed or generation fails.
    """
    args = request.args
    prompt = args.get("prompt", "")
    try:
        params = {
            "negative_prompt": args.get("negative_prompt", ""),
            "steps": _query_arg(args, "steps", 35, int),
            "cfg_scale": _query_arg(args, "cfgs", 7, float),
            "sampler": args.get("sampler", "DPM++ 2M Karras"),
            "strength": _query_arg(args, "strength", 0.7, float),
            "seed": _query_arg(args, "seed", -1, int),
            "width": _query_arg(args, "width", 1024, int),
            "height": _query_arg(args, "height", 1024, int),
            "num_inference_steps": _query_arg(args, "num_inference_steps", 30, int),
            "guidance_scale": _query_arg(args, "guidance_scale", 7.5, float),
            "top_k": _query_arg(args, "top_k", 50, int),
            "top_p": _query_arg(args, "top_p", 0.9, float),
            "eta": _query_arg(args, "eta", 0.1, float),
        }
    except ValueError as e:
        # Malformed numeric input is a client error, not a server crash.
        return jsonify({"error": str(e)}), 400

    # Key the cache on every generation parameter, not just the prompt, so
    # requests that differ only in size/seed/negative prompt don't receive
    # each other's image. NOTE: with seed=-1 an identical request still gets
    # the cached image, as before.
    cache_key = "generate:" + repr((prompt, sorted(params.items())))
    cached_image = cache.get(cache_key)
    if cached_image is not None:
        return await send_file(io.BytesIO(cached_image), mimetype='image/png')

    image, error = await query_async(prompt, **params)

    if error:
        return jsonify({"error": error}), 400

    img_bytes = io.BytesIO()
    image.save(img_bytes, format='PNG')
    img_bytes.seek(0)

    # Store the encoded PNG bytes for subsequent identical requests.
    cache.set(cache_key, img_bytes.getvalue())

    return await send_file(img_bytes, mimetype='image/png')

if __name__ == "__main__":
    # Serve the app on all interfaces (0.0.0.0) at port 7860.
    app.run(port=7860, host='0.0.0.0')