# Spaces: Runtime error  (hosting-page status banner captured with this file; not part of the program)
import asyncio
import logging
import os
import subprocess
import uuid

import discord
import torch
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
from huggingface_hub import snapshot_download
from transformers import pipeline
# Logging setup: emit DEBUG-and-above records through a stream handler.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    handlers=[logging.StreamHandler()],
)

# Intents setup: the defaults plus message content, which the bot reads to
# detect the "!image " command.
intents = discord.Intents.default()
intents.message_content = True

# Hugging Face model download. This runs when the module is loaded; the token
# is read from the HF_TOKEN environment variable and may be None.
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # Glob patterns for files to skip. The previous values (".md",
    # "..gitattributes") had no wildcard and therefore only matched files with
    # exactly those names, so nothing was actually being ignored.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# Model loading function
def load_pipeline(pipeline_type):
    """Load a Stable Diffusion 3 pipeline from the downloaded model directory.

    Args:
        pipeline_type: ``"text2img"`` for the text-to-image pipeline or
            ``"img2img"`` for the image-to-image pipeline.

    Returns:
        The pipeline object returned by ``from_pretrained``, requested with
        ``torch.float16`` weights.

    Raises:
        ValueError: If ``pipeline_type`` is not one of the supported names.
            Previously the function fell through and returned ``None``, which
            only surfaced later as an unrelated ``AttributeError``.
    """
    logging.debug(f'Loading pipeline: {pipeline_type}')
    if pipeline_type == "text2img":
        return StableDiffusion3Pipeline.from_pretrained(
            model_path, torch_dtype=torch.float16, use_fast=True
        )
    if pipeline_type == "img2img":
        return StableDiffusion3Img2ImgPipeline.from_pretrained(
            model_path, torch_dtype=torch.float16, use_fast=True
        )
    raise ValueError(
        f"Unknown pipeline type {pipeline_type!r}; expected 'text2img' or 'img2img'"
    )
# Device setup: pick the first CUDA device when CUDA is available, otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Translation pipeline setup: a "translation" pipeline backed by the
# Helsinki-NLP/opus-mt-ko-en model (Korean -> English, going by its name).
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Discord bot class
class MyClient(discord.Client):
    """Discord client that answers ``!image <prompt>`` with a generated image.

    The prompt is passed through ``translate_prompt``, rendered with the
    text-to-image pipeline, written to a temporary PNG and uploaded to the
    channel the command came from.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the client and load the text-to-image pipeline.

        All positional and keyword arguments are forwarded to
        ``discord.Client``.
        """
        super().__init__(*args, **kwargs)
        # True while an image request is being translated/rendered/sent.
        self.is_processing = False
        # Handle of the web.py child process; None until on_ready starts it.
        self._web_process = None
        # asyncio.Lock serialising image requests. Created lazily inside the
        # running event loop (see on_message) so it is never bound to a
        # different loop than the one the client runs on.
        self._generation_lock = None
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # memory optimization

    async def on_ready(self):
        """Log the login and make sure the web.py child process is running.

        This handler can be invoked more than once over the life of the
        client, so web.py is only spawned when no live child exists; otherwise
        every repeat call would launch another copy.
        """
        # NOTE(review): the two log messages below look mis-encoded (mojibake)
        # in this copy of the file; they are kept as found.
        logging.info(f'{self.user}λ‘ λ‘κ·ΈμΈλμμ΅λλ€!')
        if self._web_process is None or self._web_process.poll() is not None:
            self._web_process = subprocess.Popen(["python", "web.py"])
            logging.info("web.py μλ²κ° μμλμμ΅λλ€.")

    async def on_message(self, message):
        """Handle an incoming message; react only to the ``!image `` command.

        Messages authored by the bot itself and messages that do not start
        with ``!image `` are ignored. Requests are processed one at a time;
        later requests wait for the lock instead of running concurrently
        against the single shared pipeline.
        """
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return

        # No await between the check and the assignment, so this is race-free
        # within the event loop.
        if self._generation_lock is None:
            self._generation_lock = asyncio.Lock()

        async with self._generation_lock:
            self.is_processing = True
            image_path = None
            try:
                prompt = message.content[len('!image '):]
                # Translation is synchronous model inference; run it in a
                # worker thread so the event loop stays responsive.
                loop = asyncio.get_running_loop()
                prompt_en = await loop.run_in_executor(None, translate_prompt, prompt)
                logging.debug(f'Translated prompt: {prompt_en}')
                image_path = await self.generate_image(prompt_en)
                await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
            finally:
                self.is_processing = False
                # The PNG is only needed for the upload; delete it so repeated
                # requests do not accumulate files under /tmp.
                if image_path is not None:
                    try:
                        os.remove(image_path)
                    except OSError:
                        logging.warning('Could not remove temporary image %s', image_path)

    async def generate_image(self, prompt):
        """Render ``prompt`` to a PNG and return the file path.

        The rendering itself is synchronous, so it is executed in the default
        executor rather than directly on the event loop.

        Args:
            prompt: Text prompt passed to the text-to-image pipeline.

        Returns:
            Path of the saved PNG under ``/tmp``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._render_image, prompt)

    def _render_image(self, prompt):
        """Synchronously run the pipeline and save the first image to /tmp."""
        # torch.seed() supplies a fresh seed for each call.
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path
# Prompt translation function
def translate_prompt(prompt):
    """Translate ``prompt`` using the module-level ``translator`` pipeline.

    Args:
        prompt: Prompt text to translate.

    Returns:
        The ``translation_text`` field of the first translation result. An
        empty or whitespace-only prompt is returned unchanged without calling
        the model, since there is nothing to translate.
    """
    logging.debug(f'Translating prompt: {prompt}')
    if not prompt or not prompt.strip():
        return prompt
    # NOTE(review): input longer than the model's limit is not truncated
    # here; consider passing truncation=True after confirming it against the
    # pipeline documentation.
    translation = translator(prompt, max_length=512)
    translated_text = translation[0]['translation_text']
    logging.debug(f'Translated text: {translated_text}')
    return translated_text
# Discord token and bot startup
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Stop before MyClient() is constructed: building the client loads the
        # image pipeline, and run() with a missing token would only fail
        # afterwards with a less direct error.
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)