Spaces:
Runtime error
Runtime error
import discord | |
import logging | |
import os | |
import asyncio | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import torch | |
import re | |
import requests | |
from PIL import Image | |
import io | |
import gradio as gr | |
import threading | |
from huggingface_hub import InferenceClient | |
# Logging: DEBUG and above, timestamped, written through a StreamHandler.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s:%(message)s',
    handlers=[logging.StreamHandler()],
)

# Discord gateway intents: the library defaults with each listed flag switched on.
intents = discord.Intents.default()
for _flag_name in ("message_content", "messages", "guilds", "guild_messages"):
    setattr(intents, _flag_name, True)
del _flag_name

# Client for the hosted chat model (API token read from HF_TOKEN).
hf_client = InferenceClient("CohereForAI/aya-23-35B", token=os.getenv("HF_TOKEN"))

# PaliGemma long-caption model (CPU, eval mode) and its matching processor.
_CAPTIONER_REPO = "gokaygokay/sd3-long-captioner"
model = PaliGemmaForConditionalGeneration.from_pretrained(_CAPTIONER_REPO).to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(_CAPTIONER_REPO)

# Chat history kept in one global list and shared by every message the bot answers.
conversation_history = []
def modify_caption(caption: str) -> str:
    """Remove the first "captured from " / "captured at " phrase from *caption*.

    Matching is case-insensitive and not anchored to the start of the string;
    only the first occurrence is replaced.

    Args:
        caption: Raw caption text produced by the captioning model.

    Returns:
        The caption with the first matching phrase removed, or the caption
        unchanged when no phrase is present.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Key the lookup by lowercase text: with re.IGNORECASE the matched text can
    # be "Captured from " etc., which an original-case key would not find
    # (previously a KeyError).
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
async def create_captions_rich(image: Image.Image) -> str:
    """Generate an English long-form caption for *image* with PaliGemma.

    Preprocessing, generation and decoding are CPU-bound, so the whole
    pipeline runs in the default executor to keep the event loop responsive.

    Args:
        image: The picture to caption.

    Returns:
        The decoded caption, passed through ``modify_caption`` to drop the
        first "captured from " / "captured at " phrase.
    """
    def _generate() -> str:
        prompt = "caption en"
        # Give the processor the raw image: it resizes, rescales and normalizes
        # it itself. The previous code re-fed already-normalized pixel_values
        # (scaled by 255 and cast to uint8, which wraps negatives), so the
        # model saw a corrupted image.
        model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
        # Keep only the newly generated tokens, not the echoed prompt.
        decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
        return modify_caption(decoded)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _generate)
async def translate_to_korean(text: str) -> str:
    """Translate English *text* to Korean with the hosted chat model.

    Args:
        text: English text to translate.

    Returns:
        The concatenated streamed reply with surrounding whitespace stripped,
        or "" when *text* is empty or whitespace-only (no request is made).
    """
    if not (text and text.strip()):
        return ""

    messages = [
        {"role": "system", "content": "Translate the following text from English to Korean."},
        {"role": "user", "content": text},
    ]

    def _complete() -> str:
        # Consume the whole stream inside the executor: iterating it performs
        # blocking network reads that must stay off the event-loop thread.
        stream = hf_client.chat_completion(
            messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85
        )
        chunks = []
        for part in stream:
            if part.choices and part.choices[0].delta and part.choices[0].delta.content:
                chunks.append(part.choices[0].delta.content)
        return "".join(chunks)

    loop = asyncio.get_running_loop()
    full_response_text = await loop.run_in_executor(None, _complete)
    return full_response_text.strip()
async def interact_with_model(user_input: str, max_turns: int = 10) -> str:
    """Send *user_input* to the chat model together with the shared history.

    The exchange (user turn plus assistant reply) is appended to the global
    ``conversation_history`` only after the request succeeds. The history is
    then trimmed to the most recent *max_turns* exchanges, so the prompt
    cannot grow past the model's context window.

    Args:
        user_input: The user's message text.
        max_turns: Number of most recent user/assistant exchanges to keep;
            values <= 0 keep no prior context.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    user_message = {"role": "user", "content": user_input}
    messages = [
        {"role": "system", "content": "Translate the following text from English to Korean and respond as if you are an assistant who provides detailed answers in Korean."},
        *conversation_history,
        user_message,
    ]

    def _complete() -> str:
        # Consume the stream inside the executor: iterating it performs
        # blocking network reads that must stay off the event-loop thread.
        stream = hf_client.chat_completion(
            messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85
        )
        chunks = []
        for part in stream:
            if part.choices and part.choices[0].delta and part.choices[0].delta.content:
                chunks.append(part.choices[0].delta.content)
        return "".join(chunks)

    loop = asyncio.get_running_loop()
    reply = (await loop.run_in_executor(None, _complete)).strip()

    # Record the turn only once it succeeded, so a failed request does not
    # leave an unanswered user message in the history.
    conversation_history.append(user_message)
    conversation_history.append({"role": "assistant", "content": reply})
    # Entries are appended in user/assistant pairs, so trimming to an even
    # count keeps the pairs aligned.
    excess = len(conversation_history) - 2 * max(max_turns, 0)
    if excess > 0:
        del conversation_history[:excess]
    return reply
# Gradio interface
def create_captions_rich_sync(image):
    """Synchronous Gradio handler: caption *image* and translate it to Korean.

    Runs the caption and translation coroutines on a single event loop.
    Assumes Gradio calls this from a thread with no running event loop,
    since asyncio.run would fail otherwise.

    Raises:
        gr.Error: If no image was supplied, shown to the user in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    async def _caption_and_translate():
        caption = await create_captions_rich(image)
        return await translate_to_korean(caption)

    return asyncio.run(_caption_and_translate())
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    # Closing tags fixed: the title previously reopened <center>/<h1> instead of closing them.
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning</center></h1>")
    with gr.Tab(label="PaliGemma Long Captioner"):
        with gr.Row():
            with gr.Column():
                # type="pil" passes a PIL image to the handler, matching the
                # Image.Image input that create_captions_rich is annotated with.
                input_img = gr.Image(label="Input Picture", type="pil")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Caption")
        submit_btn.click(create_captions_rich_sync, [input_img], [output])
# Launch the Gradio UI (on_ready starts this in a background daemon thread).
def run_gradio():
    """Serve the Gradio demo on all network interfaces.

    The port is read from GRADIO_SERVER_PORT and defaults to 7861; a local
    browser tab is requested via inbrowser=True.
    """
    port = int(os.getenv("GRADIO_SERVER_PORT", 7861))
    demo.launch(server_name="0.0.0.0", server_port=port, inbrowser=True)
# Discord channel the bot serves (threads whose parent is this channel count too).
# Read from DISCORD_CHANNEL_ID; falls back to 123456789012345678, apparently a placeholder.
_channel_id_setting = os.getenv("DISCORD_CHANNEL_ID", "123456789012345678")
SPECIFIC_CHANNEL_ID = int(_channel_id_setting)
# Discord bot
class MyClient(discord.Client):
    """Discord client that captions images and chats in one configured channel.

    Only messages in SPECIFIC_CHANNEL_ID, or in threads whose parent is that
    channel, are handled. One message is processed at a time; messages that
    arrive while a request is in flight are ignored.
    """

    # Discord rejects message content longer than this many characters.
    MAX_MESSAGE_LENGTH = 2000

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a message is being handled; concurrent messages are dropped.
        self.is_processing = False
        # on_ready can fire again after a reconnect; start Gradio only once,
        # otherwise the second launch fails on the already-bound port.
        self._gradio_started = False

    async def on_ready(self):
        logging.info('%s로 로그인되었습니다!', self.user)
        if self._gradio_started:
            return
        self._gradio_started = True
        threading.Thread(target=run_gradio, daemon=True).start()
        logging.info("Gradio 서버가 시작되었습니다.")

    async def on_message(self, message):
        if message.author == self.user:
            return
        if not self.is_message_in_specific_channel(message):
            return
        if self.is_processing:
            return
        self.is_processing = True
        try:
            if message.attachments:
                # Only the first attachment is captioned.
                image_url = message.attachments[0].url
                response = await process_image(image_url, message)
            else:
                response = await interact_with_model(message.content)
            await self.send_long_message(message.channel, response)
        finally:
            self.is_processing = False

    async def send_long_message(self, channel, text):
        """Send *text* to *channel* in chunks of at most MAX_MESSAGE_LENGTH characters.

        Nothing is sent for empty text, since Discord rejects empty messages.
        """
        limit = self.MAX_MESSAGE_LENGTH
        for start in range(0, len(text), limit):
            await channel.send(text[start:start + limit])

    def is_message_in_specific_channel(self, message):
        """Return True if *message* is in the configured channel or one of its threads."""
        return message.channel.id == SPECIFIC_CHANNEL_ID or (
            isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
        )
async def process_image(image_url, message):
    """Caption the image at *image_url* and build a Korean reply for *message*.

    Downloads the image, captions it in English, translates the caption to
    Korean, and returns a reply that mentions the message author and invites
    follow-up questions.
    """
    image = await download_image(image_url)
    caption = await create_captions_rich(image)
    translated_caption = await translate_to_korean(caption)
    intro_message = (
        f"{message.author.mention}, 인식된 이미지 설명: {translated_caption}"
        "\n\n질문이 있으면 물어보세요!"
    )
    return intro_message
async def download_image(url, timeout=30):
    """Fetch the image at *url* and return it as an RGB ``PIL.Image``.

    The blocking HTTP request and image decoding run in the default executor
    so they do not stall the event loop.

    Args:
        url: Address of the image (e.g. a Discord attachment URL).
        timeout: Seconds to wait on the server, passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond in time.
        PIL.UnidentifiedImageError: If the payload is not a readable image.
    """
    def _fetch_and_decode():
        response = requests.get(url, timeout=timeout)
        # Fail on 4xx/5xx instead of trying to decode an error page as an image.
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as raw:
            # convert() loads the pixel data, so the returned copy stays valid
            # after the source image is closed.
            return raw.convert("RGB")

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_and_decode)
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Fail fast with a readable message; a None token would otherwise
        # fail inside discord.py with a much less obvious error.
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)