soiz commited on
Commit
6a3ef3a
1 Parent(s): 95d6160

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -83
app.py CHANGED
@@ -1,80 +1,53 @@
1
  from flask import Flask, request, jsonify, send_file, render_template_string, make_response
2
- import requests
3
- import io
4
- import random
5
- from PIL import Image
6
  from deep_translator import GoogleTranslator
 
 
 
 
 
 
7
 
8
  app = Flask(__name__)
9
 
10
- API_URL = "https://api-inference.huggingface.co/models/Ojimi/anime-kawai-diffusion"
11
- timeout = 3000 # タイムアウトを300秒に設定
12
 
13
- # Function to query the API and return the generated image
14
- def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024):
15
- if not prompt:
16
- return None, "Prompt is required"
 
 
 
 
 
17
 
18
- key = random.randint(0, 999)
 
 
 
 
19
 
20
- # Translate the prompt from Russian to English if necessary
21
- prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
22
- print(f'Generation {key} translation: {prompt}')
23
 
24
- # Add some extra flair to the prompt
25
- prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
26
- print(f'Generation {key}: {prompt}')
27
-
28
- payload = {
29
- "inputs": prompt,
30
- "is_negative": False,
31
- "steps": steps,
32
- "cfg_scale": cfg_scale,
33
- "seed": seed if seed != -1 else random.randint(1, 1000000000),
34
- "strength": strength,
35
- "parameters": {
36
- "width": width,
37
- "height": height
38
- }
39
- }
40
-
41
- for attempt in range(3): # 最大3回の再試行
42
- try:
43
- # Authorization header is removed
44
- response = requests.post(API_URL, json=payload, timeout=timeout)
45
- if response.status_code != 200:
46
- return None, f"Error: Failed to get image. Status code: {response.status_code}, Details: {response.text}"
47
-
48
- image_bytes = response.content
49
- image = Image.open(io.BytesIO(image_bytes))
50
- return image, None
51
- except requests.exceptions.Timeout:
52
- if attempt < 2: # 最後の試行でない場合は再試行
53
- print("Timeout occurred, retrying...")
54
- continue
55
- return None, "Error: The request timed out. Please try again."
56
- except requests.exceptions.RequestException as e:
57
- return None, f"Request Exception: {str(e)}"
58
- except Exception as e:
59
- return None, f"Error when trying to open the image: {e}"
60
-
61
- # Content-Security-Policyヘッダーを設定するための関数
62
- @app.after_request
63
- def add_security_headers(response):
64
- response.headers['Content-Security-Policy'] = (
65
- "default-src 'self'; "
66
- "connect-src 'self' ^https?:\/\/[\w.-]+\.[\w.-]+(\/[\w.-]*)*(\?[^\s]*)?$"
67
- "img-src 'self' data:; "
68
- "style-src 'self' 'unsafe-inline'; "
69
- "script-src 'self' 'unsafe-inline'; "
70
- )
71
- return response
72
 
73
  # HTML template for the index page
74
  index_html = """
75
  <!DOCTYPE html>
76
  <html lang="ja">
77
- kawai diffusion
 
 
 
 
 
 
 
 
 
 
78
  </html>
79
  """
80
 
@@ -82,27 +55,51 @@ kawai diffusion
82
  def index():
83
  return render_template_string(index_html)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  @app.route('/generate', methods=['GET'])
86
  def generate_image():
87
- prompt = request.args.get("prompt", "")
88
- negative_prompt = request.args.get("negative_prompt", "")
89
- steps = int(request.args.get("steps", 35))
90
- cfg_scale = float(request.args.get("cfgs", 7))
91
- sampler = request.args.get("sampler", "DPM++ 2M Karras")
92
- strength = float(request.args.get("strength", 0.7))
93
- seed = int(request.args.get("seed", -1))
94
- width = int(request.args.get("width", 1024))
95
- height = int(request.args.get("height", 1024))
96
-
97
- image, error = query(prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height)
98
-
99
- if error:
100
- return jsonify({"error": error}), 400
101
-
102
- img_bytes = io.BytesIO()
103
- image.save(img_bytes, format='PNG')
104
- img_bytes.seek(0)
105
- return send_file(img_bytes, mimetype='image/png')
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  if __name__ == "__main__":
108
- app.run(host='0.0.0.0', port=7860)
 
1
  from flask import Flask, request, jsonify, send_file, render_template_string, make_response
 
 
 
 
2
  from deep_translator import GoogleTranslator
3
+ from PIL import Image
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ import random
7
+ import io
8
+ import os
9
 
10
  app = Flask(__name__)
11
 
12
+ MODEL_NAME = "Ojimi/anime-kawai-diffusion"
13
+ MODEL_DIR = "./models/anime-kawai-diffusion" # Directory to store the model
14
 
15
+ # Download and load the model at startup
16
+ def load_model():
17
+ if not os.path.exists(MODEL_DIR):
18
+ print(f"Downloading the model {MODEL_NAME}...")
19
+ pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
20
+ pipeline.save_pretrained(MODEL_DIR)
21
+ else:
22
+ print(f"Loading model from {MODEL_DIR}...")
23
+ pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
24
 
25
+ if torch.cuda.is_available():
26
+ pipeline.to("cuda")
27
+ print("Model loaded on GPU")
28
+ else:
29
+ print("GPU not available. Running on CPU.")
30
 
31
+ return pipeline
 
 
32
 
33
+ # Load the model once during startup
34
+ pipeline = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # HTML template for the index page
37
  index_html = """
38
  <!DOCTYPE html>
39
  <html lang="ja">
40
+ <head>
41
+ <title>Kawaii Diffusion</title>
42
+ </head>
43
+ <body>
44
+ <h1>Kawaii Diffusion Image Generator</h1>
45
+ <form action="/generate" method="get">
46
+ <label for="prompt">Prompt:</label>
47
+ <input type="text" id="prompt" name="prompt" required><br><br>
48
+ <button type="submit">Generate Image</button>
49
+ </form>
50
+ </body>
51
  </html>
52
  """
53
 
 
55
  def index():
56
  return render_template_string(index_html)
57
 
58
+ # Function to generate image locally
59
+ def generate_image_locally(prompt, steps=35, cfg_scale=7, width=512, height=512, seed=-1):
60
+ # Translate prompt from Russian to English
61
+ prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
62
+ print(f'Translated prompt: {prompt}')
63
+
64
+ # Set a random seed if not provided
65
+ generator = torch.manual_seed(seed if seed != -1 else random.randint(1, 1_000_000))
66
+
67
+ # Generate the image using the loaded pipeline
68
+ image = pipeline(prompt, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator).images[0]
69
+ return image
70
+
71
  @app.route('/generate', methods=['GET'])
72
  def generate_image():
73
+ try:
74
+ prompt = request.args.get("prompt", "")
75
+ steps = int(request.args.get("steps", 35))
76
+ cfg_scale = float(request.args.get("cfgs", 7))
77
+ width = int(request.args.get("width", 512))
78
+ height = int(request.args.get("height", 512))
79
+ seed = int(request.args.get("seed", -1))
80
+
81
+ # Generate the image locally
82
+ image = generate_image_locally(prompt, steps, cfg_scale, width, height, seed)
83
+
84
+ # Save the image to a BytesIO object
85
+ img_bytes = io.BytesIO()
86
+ image.save(img_bytes, format='PNG')
87
+ img_bytes.seek(0)
88
+
89
+ return send_file(img_bytes, mimetype='image/png')
90
+ except Exception as e:
91
+ return jsonify({"error": str(e)}), 500
92
+
93
+ # Content-Security-Policyヘッダーを設定する
94
+ @app.after_request
95
+ def add_security_headers(response):
96
+ response.headers['Content-Security-Policy'] = (
97
+ "default-src 'self'; "
98
+ "img-src 'self' data:; "
99
+ "style-src 'self' 'unsafe-inline'; "
100
+ "script-src 'self' 'unsafe-inline'; "
101
+ )
102
+ return response
103
 
104
  if __name__ == "__main__":
105
+ app.run(host='0.0.0.0', port=7860)