# Spaces: Runtime error
import discord | |
from discord.ext import commands | |
import requests | |
import os | |
import spaces | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from threading import Thread | |
import asyncio | |
# Discord bot token, taken from the environment (raises KeyError if unset).
TOKEN = os.environ["TOKEN"]

# Command bot using the "?" prefix and requesting every gateway intent.
_bot_intents = discord.Intents.all()
client = commands.Bot(command_prefix='?', intents=_bot_intents)

# CUDA becomes the default tensor device before the pipeline is built,
# presumably so the model is loaded onto the GPU — confirm.
torch.set_default_device("cuda")
_MODEL_ID = "nroggendorff/mayo"
pipe = pipeline("text-generation", model=_MODEL_ID)

# Conversation log shared by every channel: a list of
# [user_text, assistant_text] pairs maintained by append().
history = []
def append(content, creator):
    """Record one side of a conversational exchange in the shared ``history``.

    A ``"user"`` entry opens a new ``[content, None]`` pair. An
    ``"assistant"`` entry fills the reply slot of the most recent pair
    (an IndexError is raised if ``history`` is empty). Any other *creator*
    value is ignored.
    """
    if creator == "assistant":
        history[-1][1] = content
        return
    if creator == "user":
        history.append([content, None])
def predict(input_text, history, max_new_tokens=1024):
    """Generate the assistant's reply to *input_text* given earlier turns.

    Builds a chat-format message list and runs it through the module-level
    ``pipe`` text-generation pipeline. The list holds, in order:

    - the system prompt;
    - every completed user/assistant exchange in *history*;
    - *input_text* as the newest user turn.

    Args:
        input_text: The new user message to answer.
        history: Sequence of ``[user_text, assistant_text]`` pairs. A pair
            whose assistant text is still ``None`` (recorded before its reply
            was known) is skipped. This stops the new message from appearing
            twice and stops ``None`` content from reaching the chat template.
        max_new_tokens: Upper bound on generated tokens (default 1024, the
            value previously hard-coded).

    Returns:
        The ``content`` of the last message in the pipeline's
        ``generated_text`` output.
    """
    # NOTE(review): `spaces` is imported at file top but unused in the visible
    # code; if this Space runs on ZeroGPU hardware this function probably
    # needs `@spaces.GPU` — confirm.
    system_prompt = "You are discord user Jimmy Donaldson. You are an american youtuber by the name of MrBeast. Your sole purpose in life is to get people to laugh at you. You make innapropriate jokes, and aren't very funny."
    conv = [{"role": "system", "content": system_prompt}]
    for item in history:
        user_text, assistant_text = item[0], item[1]
        if assistant_text is None:
            # Pending exchange: no reply yet, so it must not be replayed.
            continue
        conv.append({"role": "user", "content": user_text})
        conv.append({"role": "assistant", "content": assistant_text})
    conv.append({"role": "user", "content": input_text})
    outputs = pipe(conv, max_new_tokens=max_new_tokens)
    return outputs[0]['generated_text'][-1]['content']
def generate(content):
    """Produce a reply to *content* and record the exchange in ``history``.

    The model is queried first, and the exchange is stored only after the
    query succeeds. This has two effects:

    - the prompt built by ``predict`` contains only completed exchanges, with
      no duplicate of *content* and no ``None`` placeholder reply;
    - a failed generation leaves ``history`` untouched.

    Returns:
        The generated reply text.
    """
    response = predict(content, history)
    append(content, "user")
    append(response, "assistant")
    return response
async def respond(channel, message, limit=2000):
    """Send *message* to *channel*, split into pieces Discord will accept.

    Discord rejects message content longer than 2000 characters, and model
    output (up to 1024 new tokens) can exceed that. The text is therefore sent
    as consecutive chunks of at most *limit* characters, in order. An empty
    *message* produces no send call; Discord also rejects empty messages.
    """
    for start in range(0, len(message), limit):
        await channel.send(message[start:start + limit])
@client.event
async def on_ready():
    """Print the bot's identity once discord.py reports the client ready.

    Registered with ``@client.event``; without it the client never calls this
    handler.
    """
    print(f"We logged in bitches! - {client.user}")
@client.event
async def on_message(ctx):
    """Route each incoming message to the appropriate reply handler.

    Routing rules:

    - messages from the bot itself are ignored;
    - direct messages (no guild) always get a reply;
    - guild messages get one only when they mention the bot.

    Registered with ``@client.event``; without it the handler never runs.
    Registering it replaces ``commands.Bot``'s default ``on_message``, so
    ``process_commands`` is called explicitly to keep "?" commands working.
    """
    if ctx.author == client.user:
        return
    if ctx.guild is None:
        await process_dm(ctx)
    elif client.user in ctx.mentions:
        await process_mention(ctx)
    await client.process_commands(ctx)
async def _reply_to(message):
    """Generate a reply to *message* and post it in the same channel.

    Steps:

    1. Remove the bot's own mention from the text, in both the ``<@id>`` form
       and the legacy nickname ``<@!id>`` form, so the raw mention token is
       not fed to the model.
    2. Show the typing indicator while the reply is generated and sent.
    3. Print the author, the prompt and the reply.
    """
    bot_id = client.user.id
    content = (
        message.content.replace(f"<@!{bot_id}>", "")
        .replace(f"<@{bot_id}>", "")
        .strip()
    )
    async with message.channel.typing():
        # NOTE(review): generate() is synchronous and runs on the event loop,
        # so the bot cannot service Discord (heartbeats, other messages) until
        # it returns. Moving it to an executor needs care: it mutates the
        # shared history list, and torch's default device may not carry over
        # to worker threads — confirm before changing.
        response = generate(content)
        await respond(message.channel, response)
    print(message.author)
    print(content)
    print(response)


async def process_dm(message):
    """Reply to a direct message sent to the bot."""
    await _reply_to(message)


async def process_mention(message):
    """Reply to a guild message that mentions the bot."""
    await _reply_to(message)
# Feed of every message submitted through the web UI, shared by all callers.
messages = []


def refresh(new_message, markdown_text):
    """Add *new_message* to the shared feed and return the updated view.

    Returns a ``(display_text, markdown_text)`` pair:

    - ``display_text`` is every stored message, oldest first, joined by
      newlines;
    - ``markdown_text`` is echoed back unchanged.
    """
    messages.append(new_message)
    feed = "\n".join(messages)
    return feed, markdown_text
with gr.Blocks() as demo:
    with gr.Row():
        message_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
        display_area = gr.Text(label="Messages", interactive=False)
        markdown_area = gr.Markdown("# This app is community funded! Any message sent will appear everywhere. To help support more projects like this, my cash tag is $Noa087!")
    submit_button = gr.Button("Send")
    # refresh(new_message, markdown_text) takes two arguments and returns
    # (display_text, markdown_text). Wiring only one input and one output made
    # every click fail with a missing-argument error, so the markdown component
    # is passed in and written back (its text is returned unchanged).
    submit_button.click(
        refresh,
        inputs=[message_input, markdown_area],
        outputs=[display_area, markdown_area],
    )
def rungradio():
    """Launch the module-level ``demo`` Gradio app.

    ``main`` runs this in a thread-pool executor rather than on its event loop.
    """
    demo.launch()
def runclient():
    """Connect the Discord bot with ``TOKEN`` using ``client.run``.

    ``client.run`` sets up and owns its own event loop, so call this only from
    synchronous code that is not already inside a running loop.
    """
    client.run(TOKEN)
async def main():
    """Run the Gradio UI and the Discord bot concurrently.

    ``rungradio`` is a plain (non-async) function, so it runs on the default
    thread-pool executor. The bot is started with ``client.start``, the
    coroutine counterpart of ``client.run``. The original approach had two
    problems:

    - ``client.run`` (via ``runclient``) creates and drives its own event
      loop, so it cannot be called from inside this already-running one;
    - ``runclient()`` returns ``None``, which ``asyncio.create_task`` rejects.
    """
    loop = asyncio.get_running_loop()
    gradio_task = loop.run_in_executor(None, rungradio)
    discord_task = asyncio.create_task(client.start(TOKEN))
    await asyncio.gather(gradio_task, discord_task)