from flask import Flask, request, jsonify, send_file, render_template_string
from deep_translator import GoogleTranslator
import torch
from diffusers import StableDiffusionPipeline
import random
import io
import os
app = Flask(__name__)
MODEL_NAME = "Ojimi/anime-kawai-diffusion"
MODEL_DIR = "./models/anime-kawai-diffusion" # Directory to store the model
# Download and load the model at startup
def load_model():
    if not os.path.exists(MODEL_DIR):
        print(f"Downloading the model {MODEL_NAME}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
        pipeline.save_pretrained(MODEL_DIR)
    else:
        print(f"Loading model from {MODEL_DIR}...")
        pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
    if torch.cuda.is_available():
        pipeline.to("cuda")
        print("Model loaded on GPU")
    else:
        # Note: the float16 weights loaded above may be slow or unsupported on some CPU setups
        print("GPU not available. Running on CPU.")
    return pipeline

# Load the model once during startup
pipeline = load_model()
# HTML template for the index page
index_html = """
<!DOCTYPE html>
<html>
<head><title>Kawaii Diffusion</title></head>
<body><h1>Kawaii Diffusion Image Generator</h1></body>
</html>
"""

@app.route('/')
def index():
    return render_template_string(index_html)

# Function to generate an image locally
def generate_image_locally(prompt, steps=35, cfg_scale=7, width=512, height=512, seed=-1):
    # Translate the prompt from Russian to English
    prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
    print(f'Translated prompt: {prompt}')
    # Use the given seed, or pick a random one if none was provided
    generator = torch.manual_seed(seed if seed != -1 else random.randint(1, 1_000_000))
    # Generate the image using the loaded pipeline
    image = pipeline(prompt, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator).images[0]
    return image
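
# GET /generate
# Query parameters: prompt (text; translated from Russian to English), steps,
# cfgs (guidance scale), width, height, seed (-1 picks a random seed).
# Returns the generated image as a PNG.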
@app.route('/generate', methods=['GET'])
def generate_image():
    try:
        prompt = request.args.get("prompt", "")
        steps = int(request.args.get("steps", 35))
        cfg_scale = float(request.args.get("cfgs", 7))
        width = int(request.args.get("width", 512))
        height = int(request.args.get("height", 512))
        seed = int(request.args.get("seed", -1))
        # Generate the image locally
        image = generate_image_locally(prompt, steps, cfg_scale, width, height, seed)
        # Save the image to a BytesIO object
        img_bytes = io.BytesIO()
        image.save(img_bytes, format='PNG')
        img_bytes.seek(0)
        return send_file(img_bytes, mimetype='image/png')
    except Exception as e:
        return jsonify({"error": str(e)}), 500

# Set the Content-Security-Policy header on every response
@app.after_request
def add_security_headers(response):
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        "img-src 'self' data:; "
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
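
# Example request (assumption: the server is running locally on the port configured above):
#   curl -G "http://localhost:7860/generate" \
#        --data-urlencode "prompt=аниме девушка с голубыми волосами" \
#        -d "steps=30" -d "cfgs=7" -d "seed=42" -o output.png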