kai-image-xl / app.py
seawolf2357's picture
Update app.py
021392e verified
raw
history blame
No virus
3.42 kB
import asyncio
import logging
import os
import subprocess
import sys
import uuid

import discord
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
from huggingface_hub import snapshot_download
from transformers import pipeline
# Logging setup: DEBUG and above, one stream handler, timestamped format.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    handlers=[logging.StreamHandler()],
)
# Intent setup: message_content must be enabled for on_message to read the "!image" command text.
intents = discord.Intents.default()
intents.message_content = True
# Download the Stable Diffusion 3 medium snapshot from the Hugging Face Hub.
# HF_TOKEN is forwarded as the access token (presumably required because the repo is gated — TODO confirm).
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    # NOTE(review): pinned to a pull-request ref rather than main; confirm this is still intended.
    revision="refs/pr/26",
    repo_type="model",
    # Glob-style patterns: skip Markdown docs and the .gitattributes file.
    # (The previous ".md" / "..gitattributes" entries matched nothing.)
    ignore_patterns=["*.md", ".gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# λͺ¨λΈ λ‘œλ“œ ν•¨μˆ˜
def load_pipeline(pipeline_type):
logging.debug(f'Loading pipeline: {pipeline_type}')
if pipeline_type == "text2img":
return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True)
elif pipeline_type == "img2img":
return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True)
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Translation pipeline (Helsinki-NLP opus-mt-ko-en, Korean -> English); translate_prompt
# uses it to turn user prompts into English before image generation.
translator = pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")
# Discord bot client
class MyClient(discord.Client):
    """Discord client that answers ``!image <prompt>`` with a generated image.

    The prompt is passed through ``translate_prompt`` and rendered by the SD3
    text-to-image pipeline loaded in ``__init__``. Only one request is handled
    at a time. The heavy, blocking model calls run in the event loop's default
    executor so the Discord gateway connection is not stalled while they run.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a request is in flight; on_message uses it to reject overlapping requests.
        self.is_processing = False
        # Handle to the web.py child process, so repeated on_ready events do not spawn duplicates.
        self.web_process = None
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # memory optimization

    async def on_ready(self):
        logging.info(f'{self.user}둜 λ‘œκ·ΈμΈλ˜μ—ˆμŠ΅λ‹ˆλ‹€!')
        # discord.py may dispatch on_ready more than once (e.g. after a reconnect);
        # launch the companion web server only the first time.
        if self.web_process is None:
            # sys.executable runs web.py under the same interpreter/virtualenv as the bot.
            self.web_process = subprocess.Popen([sys.executable, "web.py"])
            logging.info("web.py μ„œλ²„κ°€ μ‹œμž‘λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")

    async def on_message(self, message):
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return
        if self.is_processing:
            await message.channel.send(
                "An image is already being generated. Please try again in a moment."
            )
            return
        # No await between the check above and this assignment, so the guard
        # cannot be raced by another on_message on the same event loop.
        self.is_processing = True
        image_path = None
        try:
            prompt = message.content[len('!image '):]
            loop = asyncio.get_running_loop()
            # Translation is blocking model inference; keep it off the event loop.
            prompt_en = await loop.run_in_executor(None, translate_prompt, prompt)
            logging.debug(f'Translated prompt: {prompt_en}')
            image_path = await self.generate_image(prompt_en)
            await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
        finally:
            self.is_processing = False
            if image_path is not None:
                self._remove_file(image_path)

    async def generate_image(self, prompt):
        """Render *prompt* and return the path of the saved PNG.

        The diffusion run executes in the default executor so that it does not
        block the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_blocking, prompt)

    def _generate_image_blocking(self, prompt):
        """Synchronously run 50 inference steps with a fresh seed and save the first image under /tmp."""
        # torch.seed() reseeds the global RNG non-deterministically and returns the seed used.
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path

    @staticmethod
    def _remove_file(path):
        """Delete a temporary image on a best-effort basis; failures are logged, not raised."""
        try:
            os.remove(path)
        except OSError:
            logging.warning(f'Could not remove temporary file: {path}')
# Prompt translation helper
def translate_prompt(prompt):
    """Translate *prompt* with the module-level ``translator`` pipeline.

    The pipeline is called with ``max_length=512``. The ``translation_text``
    of its first result is returned. Both the input and the output are logged
    at DEBUG level.
    """
    logging.debug(f'Translating prompt: {prompt}')
    english = translator(prompt, max_length=512)[0]['translation_text']
    logging.debug(f'Translated text: {english}')
    return english
# Entry point: read the Discord token and run the bot.
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Check before MyClient() loads the SD3 pipeline onto the device, and
        # give a clear message instead of a login error from client.run(None).
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)