ColorpAI / app.py
Qilex's picture
Update app.py
c597797
raw
history blame
2.8 kB
#for some reason the status of this demo is 'undefined'
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import re
# Hugging Face Hub identifiers for the tokenizer and the fine-tuned classifier.
_TOKENIZER_ID = "bert-base-cased"
_MODEL_ID = "Qilex/colorpAI-monocolor"

# Both are loaded once at import time and shared by the pipeline built below.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
def round_to_2(num):
    """Return *num* rounded to two decimal places using the built-in ``round``."""
    return round(num, 2)


# Pipeline label -> display name, in the order the output dict lists them.
_COLOR_LABELS = (
    ('W', 'White'),
    ('U', 'Blue'),
    ('B', 'Black'),
    ('R', 'Red'),
    ('G', 'Green'),
    ('C', 'Colorless'),
)


def format_output(out_list):
    """Convert raw pipeline output into a ``{color name: score}`` dict.

    Args:
        out_list: list of ``{"label": ..., "score": ...}`` dicts, optionally
            wrapped in a one-element outer list (as the pipeline may return
            for a single input string).

    Returns:
        dict with keys White, Blue, Black, Red, Green, Colorless (in that
        order), each score rounded to two decimal places.

    Raises:
        ValueError: if any of the six color labels is absent from the output.
    """
    # Unwrap [[{...}, ...]] -> [{...}, ...]. Only unwrap when the single
    # element is itself a list, so a flat one-dict result is not mangled.
    if len(out_list) == 1 and isinstance(out_list[0], list):
        out_list = out_list[0]

    # One pass instead of one scan per label; on duplicate labels the last
    # entry wins, as in the previous loop-per-label implementation.
    scores = {entry['label']: entry['score'] for entry in out_list}

    missing = [code for code, _ in _COLOR_LABELS if code not in scores]
    if missing:
        raise ValueError(
            f"pipeline output is missing label(s): {', '.join(missing)}"
        )

    return {name: round_to_2(scores[code]) for code, name in _COLOR_LABELS}
def predict(card):
    """Classify *card* with the module-level ``predictor_lg`` pipeline.

    ``predictor_lg`` is defined later in the module; the name is resolved
    at call time, so it only needs to exist before the first request.
    Returns the pipeline's raw, unformatted output.
    """
    pipeline_output = predictor_lg(card)
    return pipeline_output
# Matches single mana symbols such as {W} or {C}, multi-letter {WU}, and
# slash-separated hybrids such as {W/U}. The original class [W,U,B,R,G,C]
# also matched literal commas: commas are not separators inside a regex
# character class.
_COLORED_PIP_RE = re.compile(r'\{[WUBRGC]+/*[WUBRGC]*\}')


def remove_colored_pips(text):
    """Replace every colored or colorless mana symbol in *text* with ``{?}``.

    Generic costs like ``{2}`` and non-mana symbols like ``{T}`` are left
    intact. The masking presumably keeps the model from reading the color
    directly off the mana cost.

    NOTE(review): Phyrexian ({W/P}) and two-brid ({2/W}) symbols are not
    matched and therefore not masked; extending the pattern would change the
    model's input distribution, so confirm against training preprocessing.
    """
    return _COLORED_PIP_RE.sub('{?}', text)
def preprocess_text(text):
    """Normalize raw card text before it is fed to the classifier.

    Currently the only step is masking mana symbols via
    ``remove_colored_pips``.
    """
    masked = remove_colored_pips(text)
    return masked
def categorize(Card):
    """Gradio callback: turn raw card text into per-color confidences.

    The parameter keeps its capitalized name ``Card``; Gradio presumably
    derives the input box's label from it, so renaming would change the UI
    (NOTE(review): confirm before renaming).
    """
    raw_prediction = predict(preprocess_text(Card))
    # Debug trace of the unformatted pipeline output (goes to stdout).
    print(raw_prediction)
    return format_output(raw_prediction)
# User-facing text passed to gr.Interface as its title, description and
# article (HTML is allowed in the latter two).
title = "Color pAI Version 1.0"
description = """
Color pAI is trained on around 18,000 Magic: the Gathering cards.
<br>
Input a card text using Scryfall syntax, and the model will evaluate which color it is most likely to be.
<br>Replace any card names with the word CARDNAME, and mana symbols with the uppercase letter in curly brackets, e.g. {U}.
<br>
<br>This only works on monocolored and colorless cards.
<br>
"""
# "_blank" is the HTML keyword for "open in a new browsing context"; plain
# "blank" names a window instead. rel="noopener noreferrer" stops the
# opened page from accessing window.opener.
article = '''
<br>
Magic: the Gathering is property of Wizards of the Coast. This project is made possible under their
<a href="https://company.wizards.com/en/legal/fancontentpolicy" target="_blank" rel="noopener noreferrer">fan content policy</a>.
'''
# Classification pipeline over the loaded model/tokenizer: softmax over the
# logits, returning the top 6 labels (presumably all six color classes —
# TODO confirm against model.config.num_labels).
predictor_lg = TextClassificationPipeline(
    tokenizer=tokenizer,
    model=model,
    top_k=6,
    function_to_apply='softmax',
)
# Web UI: a single-line text box in, up to six label confidences out.
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(lines=1, placeholder="Type card text here."),
    outputs=gr.Label(num_top_classes=6),
    title=title,
    description=description,
    article=article,
)
demo.launch()