soiz commited on
Commit
92b8b0d
·
verified ·
1 Parent(s): 2cfd327

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -37
app.py CHANGED
@@ -1,31 +1,43 @@
1
- from flask import Flask, request, jsonify, send_file, render_template_string, make_response
2
- import requests
3
- import io
4
  import os
5
  import random
 
6
  from PIL import Image
7
  from deep_translator import GoogleTranslator
 
 
 
8
 
9
- app = Flask(__name__)
 
 
10
 
11
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
12
-
13
  API_TOKEN = os.getenv("HF_READ_TOKEN")
14
  headers = {"Authorization": f"Bearer {API_TOKEN}"}
15
  timeout = 50000 # タイムアウトを300秒に設定
16
 
17
- # Function to query the API and return the generated image
18
- def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, num_inference_steps=30, guidance_scale=7.5, top_k=50, top_p=0.9, eta=0.1):
 
 
 
 
 
 
 
 
 
 
 
19
  if not prompt:
20
  return None, "Prompt is required"
21
 
22
  key = random.randint(0, 999)
23
-
24
  # Translate the prompt from Russian to English if necessary
25
  prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
26
  print(f'Generation {key} translation: {prompt}')
27
 
28
- # Add some extra flair to the prompt
29
  prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
30
  print(f'Generation {key}: {prompt}')
31
 
@@ -48,28 +60,24 @@ def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M K
48
  }
49
  }
50
 
51
- for attempt in range(3): # 最大3回の再試行
52
  try:
53
- response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
54
- if response.status_code != 200:
55
- return None, f"Error: Failed to get image. Status code: {response.status_code}, Details: {response.text}"
56
-
57
- image_bytes = response.content
58
- image = Image.open(io.BytesIO(image_bytes))
59
- return image, None
60
- except requests.exceptions.Timeout:
61
- if attempt < 2: # 最後の試行でない場合は再試行
62
- print("Timeout occurred, retrying...")
63
- continue
64
  return None, "Error: The request timed out. Please try again."
65
- except requests.exceptions.RequestException as e:
66
- return None, f"Request Exception: {str(e)}"
67
  except Exception as e:
68
- return None, f"Error when trying to open the image: {e}"
 
69
 
70
  # Content-Security-Policyヘッダーを設定するための関数
71
  @app.after_request
72
- def add_security_headers(response):
73
  response.headers['Content-Security-Policy'] = (
74
  "default-src 'self'; "
75
  "connect-src 'self' ^https?:\/\/[\w.-]+\.[\w.-]+(\/[\w.-]*)*(\?[^\s]*)?$"
@@ -79,22 +87,30 @@ def add_security_headers(response):
79
  )
80
  return response
81
 
82
- # HTML template for the index page
83
  index_html = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  """
85
 
86
  @app.route('/')
87
- def index():
88
- return render_template_string(index_html)
89
 
90
  @app.route('/generate', methods=['GET'])
91
- def generate_image():
92
- if request.headers.getlist("X-Forwarded-For"):
93
- client_ip = request.headers.getlist("X-Forwarded-For")[0]
94
- else:
95
- client_ip = request.remote_addr
96
-
97
- print(f"Client IP: {client_ip}")
98
  prompt = request.args.get("prompt", "")
99
  negative_prompt = request.args.get("negative_prompt", "")
100
  steps = int(request.args.get("steps", 35))
@@ -110,7 +126,12 @@ def generate_image():
110
  top_p = float(request.args.get("top_p", 0.9))
111
  eta = float(request.args.get("eta", 0.1))
112
 
113
- image, error = query(prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height, num_inference_steps, guidance_scale, top_k, top_p, eta)
 
 
 
 
 
114
 
115
  if error:
116
  return jsonify({"error": error}), 400
@@ -118,7 +139,11 @@ def generate_image():
118
  img_bytes = io.BytesIO()
119
  image.save(img_bytes, format='PNG')
120
  img_bytes.seek(0)
121
- return send_file(img_bytes, mimetype='image/png')
 
 
 
 
122
 
123
  if __name__ == "__main__":
124
  app.run(host='0.0.0.0', port=7860)
 
 
 
 
1
  import os
2
  import random
3
+ import io
4
  from PIL import Image
5
  from deep_translator import GoogleTranslator
6
+ import aiohttp
7
+ from quart import Quart, request, jsonify, send_file, render_template_string
8
+ from werkzeug.contrib.cache import SimpleCache
9
 
10
+ # アプリケーションの設定
11
+ app = Quart(__name__)
12
+ cache = SimpleCache()
13
 
14
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
 
15
  API_TOKEN = os.getenv("HF_READ_TOKEN")
16
  headers = {"Authorization": f"Bearer {API_TOKEN}"}
17
  timeout = 50000 # タイムアウトを300秒に設定
18
 
19
+ # キャッシュの設定
20
+ def get_cached_image(prompt):
21
+ cached = cache.get(prompt)
22
+ if cached is not None:
23
+ return cached
24
+ return None
25
+
26
+ def set_cache_image(prompt, image):
27
+ cache.set(prompt, image, timeout=60*60*24) # キャッシュを24時間保存
28
+
29
+
30
+ # 非同期APIリクエストの実行関数
31
+ async def query_async(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, num_inference_steps=30, guidance_scale=7.5, top_k=50, top_p=0.9, eta=0.1):
32
  if not prompt:
33
  return None, "Prompt is required"
34
 
35
  key = random.randint(0, 999)
36
+
37
  # Translate the prompt from Russian to English if necessary
38
  prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
39
  print(f'Generation {key} translation: {prompt}')
40
 
 
41
  prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
42
  print(f'Generation {key}: {prompt}')
43
 
 
60
  }
61
  }
62
 
63
+ async with aiohttp.ClientSession() as session:
64
  try:
65
+ async with session.post(API_URL, json=payload, headers=headers, timeout=timeout) as response:
66
+ if response.status != 200:
67
+ return None, f"Error: Failed to get image. Status code: {response.status}, Details: {await response.text()}"
68
+
69
+ image_bytes = await response.read()
70
+ image = Image.open(io.BytesIO(image_bytes))
71
+ return image, None
72
+ except asyncio.TimeoutError:
 
 
 
73
  return None, "Error: The request timed out. Please try again."
 
 
74
  except Exception as e:
75
+ return None, f"Request Exception: {str(e)}"
76
+
77
 
78
  # Content-Security-Policyヘッダーを設定するための関数
79
  @app.after_request
80
+ async def add_security_headers(response):
81
  response.headers['Content-Security-Policy'] = (
82
  "default-src 'self'; "
83
  "connect-src 'self' ^https?:\/\/[\w.-]+\.[\w.-]+(\/[\w.-]*)*(\?[^\s]*)?$"
 
87
  )
88
  return response
89
 
90
+ # HTMLテンプレート
91
  index_html = """
92
+ <!DOCTYPE html>
93
+ <html lang="en">
94
+ <head>
95
+ <meta charset="UTF-8">
96
+ <title>Image Generator</title>
97
+ </head>
98
+ <body>
99
+ <h1>Welcome to the Image Generator</h1>
100
+ <form action="/generate" method="get">
101
+ <input type="text" name="prompt" placeholder="Enter prompt" required>
102
+ <input type="submit" value="Generate">
103
+ </form>
104
+ </body>
105
+ </html>
106
  """
107
 
108
  @app.route('/')
109
+ async def index():
110
+ return await render_template_string(index_html)
111
 
112
  @app.route('/generate', methods=['GET'])
113
+ async def generate_image():
 
 
 
 
 
 
114
  prompt = request.args.get("prompt", "")
115
  negative_prompt = request.args.get("negative_prompt", "")
116
  steps = int(request.args.get("steps", 35))
 
126
  top_p = float(request.args.get("top_p", 0.9))
127
  eta = float(request.args.get("eta", 0.1))
128
 
129
+ # キャッシュを確認
130
+ cached_image = get_cached_image(prompt)
131
+ if cached_image:
132
+ return await send_file(cached_image, mimetype='image/png')
133
+
134
+ image, error = await query_async(prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height, num_inference_steps, guidance_scale, top_k, top_p, eta)
135
 
136
  if error:
137
  return jsonify({"error": error}), 400
 
139
  img_bytes = io.BytesIO()
140
  image.save(img_bytes, format='PNG')
141
  img_bytes.seek(0)
142
+
143
+ # 画像をキャッシュに保存
144
+ set_cache_image(prompt, img_bytes)
145
+
146
+ return await send_file(img_bytes, mimetype='image/png')
147
 
148
  if __name__ == "__main__":
149
  app.run(host='0.0.0.0', port=7860)