#for some reason the status of this demo is 'undefined'
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re
# Checkpoint identifiers passed to ``from_pretrained``.
# NOTE(review): the tokenizer comes from the base "bert-base-cased" checkpoint
# while the classifier weights come from "Qilex/colorpAI-monocolor"; presumably
# the classifier was fine-tuned from bert-base-cased -- confirm the vocabularies match.
TOKENIZER_CHECKPOINT = "bert-base-cased"
CLASSIFIER_CHECKPOINT = "Qilex/colorpAI-monocolor"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_CHECKPOINT)
def round_to_2(num):
    """Return *num* rounded to two decimal places via the built-in ``round``.

    Args:
        num: the number to round (e.g. a classifier score).

    Returns:
        ``round(num, 2)``; an int input stays an int.
    """
    return round(num, ndigits=2)
# Classifier label -> display name, in the order the Gradio output dict uses.
_COLOR_NAMES = {
    'W': 'White',
    'U': 'Blue',
    'B': 'Black',
    'R': 'Red',
    'G': 'Green',
    'C': 'Colorless',
}


def format_output(out_list):
    """Convert raw classifier scores into a ``{color name: score}`` dict.

    Args:
        out_list: a list of ``{"label": ..., "score": ...}`` dicts, or such a
            list wrapped in a one-element outer list (both shapes accepted).

    Returns:
        Dict with keys White, Blue, Black, Red, Green, Colorless (in that
        order), each mapped to its score rounded to 2 decimal places. If a
        label appears more than once, its last score wins.

    Raises:
        ValueError: if any of the six labels W/U/B/R/G/C is absent.
    """
    # Unwrap [[{...}, ...]] -> [{...}, ...]; the isinstance check leaves a
    # flat one-element list of dicts alone instead of turning it into a dict.
    if len(out_list) == 1 and isinstance(out_list[0], list):
        out_list = out_list[0]
    # Single pass over the predictions instead of one scan per color; later
    # entries overwrite earlier ones, matching last-occurrence-wins semantics.
    scores = {entry["label"]: entry["score"] for entry in out_list}
    missing = [label for label in _COLOR_NAMES if label not in scores]
    if missing:
        raise ValueError(f"classifier output is missing labels: {missing}")
    return {name: round_to_2(scores[label]) for label, name in _COLOR_NAMES.items()}
def predict(card):
    """Run the module-level ``predictor_lg`` pipeline on *card* and return its output.

    ``predictor_lg`` is a module-level global looked up at call time, so it
    must be assigned before this function is first called.
    """
    classifier = predictor_lg
    return classifier(card)
# One mana symbol built from the letters W/U/B/R/G/C, optionally hybrid with
# slashes, e.g. {W}, {C}, {U/B}. Only the mana letters belong in the character
# class -- commas inside [...] would be matched literally.
_COLORED_PIP_RE = re.compile(r'\{[WUBRGC]+/*[WUBRGC]*\}')


def remove_colored_pips(text):
    """Mask colored mana symbols in *text* with the placeholder ``{?}``.

    Symbols whose contents are only the letters W, U, B, R, G, C (optionally
    joined by ``/``) are replaced; this includes colorless ``{C}``. Generic
    mana such as ``{2}`` and other symbols such as ``{T}`` are left unchanged.

    Args:
        text: card text written in brace syntax, e.g. ``"{1}{G}: ..."``.

    Returns:
        The text with every matching symbol replaced by ``{?}``.
    """
    return _COLORED_PIP_RE.sub('{?}', text)
def preprocess_text(text):
    """Prepare raw card text for the classifier.

    The only preprocessing step is masking colored mana symbols with
    ``remove_colored_pips``; the masked text is returned.
    """
    masked = remove_colored_pips(text)
    return masked
def categorize(Card):
    """Gradio callback: classify one card's text and return formatted scores.

    Args:
        Card: the card text entered in the Gradio textbox.

    Returns:
        The result of ``format_output`` applied to the raw pipeline prediction.
    """
    raw_scores = predict(preprocess_text(Card))
    # Echo the raw pipeline output to stdout for debugging.
    print(raw_scores)
    return format_output(raw_scores)
# Text shown by the Gradio interface: page title, description above the
# input box, and the article (attribution) below it.
title = "Color pAI Version 1.0"
description = """
Color pAI is trained on around 18,000 Magic: the Gathering cards.
Input card text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
Replace any card names with the word CARDNAME, and write mana symbols as their uppercase letter enclosed in curly brackets, e.g. {U}.
This only works on monocolored and colorless cards.
"""
article = '''
Magic: the Gathering is property of Wizards of the Coast. This project is made possible under their
fan content policy.
'''
# Text-classification pipeline over the loaded model/tokenizer: softmax
# scores, with top_k=6 requesting six label scores per input (the six color
# labels handled downstream).
predictor_lg = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    function_to_apply='softmax',
    top_k=6,
)

# Build the web UI around ``categorize`` and start it; note this runs at
# module import time (there is no ``__main__`` guard).
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)
demo.launch()