seawolf2357 committed on
Commit
7262aa5
โ€ข
1 Parent(s): 12bb502

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  from huggingface_hub import InferenceClient
5
  import asyncio
6
  import subprocess
 
 
7
 
8
  # ๋กœ๊น… ์„ค์ •
9
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
@@ -24,6 +26,10 @@ SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))
24
  # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๋ฅผ ์ €์žฅํ•  ์ „์—ญ ๋ณ€์ˆ˜
25
  conversation_history = []
26
 
 
 
 
 
27
  class MyClient(discord.Client):
28
  def __init__(self, *args, **kwargs):
29
  super().__init__(*args, **kwargs)
@@ -41,15 +47,31 @@ class MyClient(discord.Client):
41
  return
42
  if self.is_processing:
43
  return
44
- self.is_processing = True
45
- try:
46
- response = await generate_response(message)
47
- await message.channel.send(response)
48
- finally:
49
- self.is_processing = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  async def generate_response(message):
52
- global conversation_history # ์ „์—ญ ๋ณ€์ˆ˜ ์‚ฌ์šฉ์„ ๋ช…์‹œ
53
  user_input = message.content
54
  user_mention = message.author.mention
55
  system_message = f"{user_mention}, DISCORD์—์„œ ์‚ฌ์šฉ์ž๋“ค์˜ ์งˆ๋ฌธ์— ๋‹ตํ•˜๋Š” ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค."
 
4
  from huggingface_hub import InferenceClient
5
  import asyncio
6
  import subprocess
7
+ import torch
8
+ from diffusers import StableDiffusionPipeline
9
 
10
  # ๋กœ๊น… ์„ค์ •
11
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
 
26
  # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๋ฅผ ์ €์žฅํ•  ์ „์—ญ ๋ณ€์ˆ˜
27
  conversation_history = []
28
 
29
+ # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ๋ชจ๋ธ ๋กœ๋“œ
30
+ if torch.cuda.is_available():
31
+ model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16).to("cuda")
32
+
33
  class MyClient(discord.Client):
34
  def __init__(self, *args, **kwargs):
35
  super().__init__(*args, **kwargs)
 
47
  return
48
  if self.is_processing:
49
  return
50
+ if message.content.startswith('!image '):
51
+ self.is_processing = True
52
+ try:
53
+ prompt = message.content[len('!image '):] # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ”„๋กฌํ”„ํŠธ ํŒŒ์‹ฑ
54
+ image_path = await generate_image(prompt)
55
+ await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
56
+ finally:
57
+ self.is_processing = False
58
+ else:
59
+ self.is_processing = True
60
+ try:
61
+ response = await generate_response(message)
62
+ await message.channel.send(response)
63
+ finally:
64
+ self.is_processing = False
65
+
66
async def generate_image(prompt):
    """Generate an image for *prompt* with Stable Diffusion and return its file path.

    The diffusion pipeline is synchronous and GPU-bound (50 inference steps),
    so it is run in a worker thread via an executor — otherwise it would block
    the Discord event loop for the entire generation.

    Returns:
        str: path of the saved PNG ('/tmp/generated_image.png', overwritten
        on every call).
    """
    def _run_pipeline():
        # torch.seed() reseeds the RNG and returns the new seed; feeding it to a
        # dedicated CUDA generator gives each call a fresh seed.
        generator = torch.Generator(device="cuda").manual_seed(torch.seed())
        # NOTE(review): "sample" is the legacy diffusers output key; newer
        # releases expose `.images` instead — confirm against the pinned
        # diffusers version before upgrading.
        image = model(prompt, num_inference_steps=50, generator=generator)["sample"][0]
        image_path = '/tmp/generated_image.png'
        image.save(image_path)
        return image_path

    loop = asyncio.get_running_loop()
    # None -> default ThreadPoolExecutor; keeps the event loop responsive.
    return await loop.run_in_executor(None, _run_pipeline)
72
 
73
  async def generate_response(message):
74
+ global conversation_history
75
  user_input = message.content
76
  user_mention = message.author.mention
77
  system_message = f"{user_mention}, DISCORD์—์„œ ์‚ฌ์šฉ์ž๋“ค์˜ ์งˆ๋ฌธ์— ๋‹ตํ•˜๋Š” ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค."