# Spaces:
# Runtime error
# Runtime error
# pip install html2image | |
import base64 | |
import random | |
from io import BytesIO | |
from html2image import Html2Image | |
import os | |
import pathlib | |
import re | |
import gradio as gr | |
import requests | |
from PIL import Image | |
from gradio_client import Client | |
import torch | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Pipeline | |
# Hugging Face Inference API endpoint of the chat model that writes the cards.
_TEXT_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
API_URL = f"https://api-inference.huggingface.co/models/{_TEXT_MODEL_ID}"

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Fail at import time rather than on the first text-generation request.
    raise Exception("HF_TOKEN environment variable is required to call remote API.")
headers = {"Authorization": "Bearer " + HF_TOKEN}

# Gradio client for the remote Space that generate_image() uses for card art.
client = Client("https://latent-consistency-super-fast-lcm-lora-sd1-5.hf.space")
def init_speech_to_text_model() -> Pipeline:
    """Build the distil-whisper automatic-speech-recognition pipeline.

    Runs on ``cuda:0`` with float16 weights when CUDA is available, otherwise
    on the CPU with float32 weights.

    Returns:
        A transformers ``Pipeline`` for ``"automatic-speech-recognition"``.
    """
    has_gpu = torch.cuda.is_available()
    device, torch_dtype = ("cuda:0", torch.float16) if has_gpu else ("cpu", torch.float32)
    model_id = "distil-whisper/distil-medium.en"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    asr_options = dict(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipeline("automatic-speech-recognition", **asr_options)


# Built once at import time; transcribe() reuses this pipeline for every request.
whisper_pipe = init_speech_to_text_model()
def query(payload: dict, timeout: float = 120):
    """POST *payload* to the text-generation Inference API and return its JSON body.

    Args:
        payload: Request body, e.g. ``{"inputs": prompt, "parameters": {...}}``.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout ``requests`` may block forever.

    Returns:
        The decoded JSON response. Network failures and non-JSON bodies
        (e.g. an HTML error page) are reported as ``{"error": <message>}``.
        This is the same form callers already check for, so they have a
        single failure path.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        return {"error": f"request failed: {err}"}
    try:
        return response.json()
    except ValueError:
        # requests' JSONDecodeError subclasses ValueError on all supported versions.
        return {"error": f"HTTP {response.status_code}: {response.text[:200]}"}
# System-prompt body shared by the "create" and "edit" requests. The example card
# shows the model the exact ``Header: value`` layout the renderer parses.
_CARD_PROMPT_RULES = """# RULES
- In your response always generate a new card.
- Only generate one card, no other dialogue.
- Surround card info in triple backticks (```).
- Format the card text using headers like in the example below:
```
Name: Band of Brothers
ManaCost: {3}{W}{W}
Type: Creature — Phyrexian Human Soldier
Rarity: rare
Text: Vigilance
{W}, {T}: Attach target creature you control to target creature. (Any number of attacking creatures with total power 5 or less can attack in a band. A band deals damage to that creature.)
FlavorText: "This time we will be stronger."
—Elder brotherhood blessing
Power: 2
Toughness: 2
Color: ['W']
```</s>
"""


def _build_card_prompt(card_text: str, user_request: str, editing: bool) -> str:
    """Assemble the Zephyr-format chat prompt for creating or editing a card."""
    action = "edit" if editing else "create"
    system = (f"<|system|>\nYou {action} Magic the Gathering cards based on the user's request.\n"
              + _CARD_PROMPT_RULES)
    if editing:
        user = (f"<|user|>\n# CARD TO EDIT\n```\n{card_text}\n```\n"
                f"# EDIT REQUEST\n{user_request}</s>\n")
    else:
        user = f"<|user|>\n{user_request}</s>\n"
    return system + user + "<|assistant|>\n"


def _extract_card_text(assistant_reply: str, fallback: str) -> str:
    """Pull the card out of the reply: first ``` block, else 2nd paragraph, else *fallback*."""
    fenced = assistant_reply.split('```')
    if len(fenced) > 1:
        return fenced[1].strip() + '\n'
    paragraphs = assistant_reply.split('\n\n')
    if len(paragraphs) < 2:
        return fallback
    return paragraphs[1].strip() + '\n'


def generate_text(card_text: str, user_request: str) -> (str, str, str):
    """Ask the chat model to create a new card or edit the current one.

    Args:
        card_text: Current contents of the card-text box. When it is non-empty
            and differs from ``starting_text`` the model is asked to edit it;
            otherwise it is asked to create a new card.
        user_request: The user's typed or transcribed request.

    Returns:
        ``(assistant_reply, new_card_text, None)``. ``new_card_text`` is the
        first fenced block of the reply, or its second paragraph, falling
        back to the unchanged *card_text*. The trailing ``None`` is sent to
        the audio input.

    Raises:
        gr.Error: if the API reports an error or returns an unexpected payload.
    """
    # Prompt must apply the correct chat template for the model see:
    # https://huggingface.co/docs/transformers/main/en/chat_templating
    editing = bool(card_text) and card_text != starting_text
    prompt = _build_card_prompt(card_text, user_request, editing)
    print(f"Calling API with prompt:\n{prompt}")
    params = {"max_new_tokens": 512}
    output = query({"inputs": prompt, "parameters": params})

    if isinstance(output, dict) and 'error' in output:
        print(f'Language model call failed: {output["error"]}')
        # gr.Warning only shows a toast and returns None, so it cannot be raised.
        # gr.Error is the exception that aborts the event and shows the message.
        raise gr.Error(f'Language model call failed: {output["error"]}')
    try:
        generated_text = output[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        print(f'Unexpected language model response: {output!r}')
        raise gr.Error('Language model returned an unexpected response.') from None

    print(f'API RESPONSE SIZE: {len(generated_text)}')
    # The response is expected to echo the prompt, so keep only the assistant
    # turn. Use the whole generation if the marker is absent.
    parts = generated_text.split('<|assistant|>')
    assistant_reply = parts[1] if len(parts) > 1 else generated_text
    print(f'ASSISTANT REPLY:\n{assistant_reply}')
    return assistant_reply, _extract_card_text(assistant_reply, card_text), None
# Card background colour, checked in order. The first marker found in the card
# text wins, so a mono-coloured card beats the Land/Artifact fallbacks.
_CARD_COLOR_RULES = (
    ("['U']", "#5a73ab"),
    ("['W']", "#f0e3d0"),
    ("['G']", "#325433"),
    ("['B']", "#1a1b1e"),
    ("['R']", "#c2401c"),
    ("Type: Land", "#aa8c71"),
    ("Type: Artifact", "#9ba7bc"),
)
_DEFAULT_CARD_COLOR = "#edd99d"

# A mana-cost token is either a braced symbol ("{10}", "{W/U}") or a single bare
# character (costs written like "3WW").
_MANA_TOKEN_RE = re.compile(r"\{([^{}]*)\}|([^{}\s])")
_TEXT_SYMBOL_RE = re.compile(r"\{([^{}]*)\}")
# Rules text runs from the "Text:" line up to the next known header line or the end.
_RULES_TEXT_RE = re.compile(
    r"^Text: (.*?)(?=\n(?:Flavor.?Text|Power|Toughness|Color):|\Z)",
    re.MULTILINE | re.DOTALL,
)
# Flavour text may span several lines (quote plus attribution).
_FLAVOR_TEXT_RE = re.compile(
    r"Flavor.?Text: (.*?)(?=\n(?:Power|Toughness|Color):|\Z)",
    re.DOTALL,
)
_TEMPLATE_PLACEHOLDER_RE = re.compile(
    r"\{(card_color|name|mana_cost|image_data|card_type|type_size|rarity"
    r"|card_text|flavor_text|text_size|power_toughness)\}"
)
_TEXT_ICON_STYLE = 'style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;"'
_SMALL_TEXT_ICON_STYLE = 'style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;"'
# Rules-text symbols that map to a named mana-font icon rather than their letters.
_SPECIAL_TEXT_SYMBOLS = {"T": "tap", "UT": "untap", "E": "instant"}


def _html_escape(value, quote=False):
    """Escape HTML special characters in LLM/user supplied card text."""
    from html import escape
    return escape(value, quote=quote)


def _card_field(text, pattern, default=""):
    """Return the first capture group of *pattern* in *text* (stripped), or *default*."""
    match = re.search(pattern, text)
    return match.group(1).strip() if match else default


def _mana_css_class(symbol):
    """Map a symbol body ('W', '10', 'W/U', '2/W') to its ms-* suffix ('w', '10', 'wu', '2w')."""
    return re.sub(r"[^a-z0-9]", "", symbol.lower())


def _card_color_style(text):
    """Return the style attribute for the card frame's background colour."""
    colour = next((c for marker, c in _CARD_COLOR_RULES if marker in text), _DEFAULT_CARD_COLOR)
    return f'style="background-color:{colour}"'


def _format_mana_cost(mana_cost):
    """Render the mana cost as mana-font icons; "None" renders an invisible spacer icon."""
    if mana_cost == "None":
        return '<i class="ms ms-cost" style="visibility: hidden"></i>'
    classes = (_mana_css_class(braced or bare) for braced, bare in _MANA_TOKEN_RE.findall(mana_cost))
    icons = [f'<i class="ms ms-{css} ms-cost ms-shadow"></i>' for css in classes if css]
    # Emitted in reverse order, as the original layout code did.
    return "\n".join(reversed(icons))


def _text_symbol_icon(match):
    """re.sub callback: replace a braced rules-text symbol with an inline icon."""
    symbol = match.group(1)
    css = _SPECIAL_TEXT_SYMBOLS.get(symbol) or _mana_css_class(symbol)
    return f'<i class="ms ms-{css} ms-cost" {_TEXT_ICON_STYLE}></i>'


def _render_rules_lines(card_text):
    """Turn rules text into <p> lines with inline symbol icons and italic reminder text."""
    lines = []
    for raw_line in card_text.splitlines():
        line = _TEXT_SYMBOL_RE.sub(_text_symbol_icon, _html_escape(raw_line))
        # Reminder text in parentheses is italicised.
        line = line.replace('(', '(<i>').replace(')', '</i>)')
        lines.append(f"<p>{line}</p>")
    return lines


def _render_flavor_text(flavor_text):
    """Wrap each flavour line in <p> inside a <blockquote>; empty text renders nothing."""
    if not flavor_text:
        return ""
    paragraphs = "\n".join(f"<p>{_html_escape(line)}</p>" for line in flavor_text.splitlines())
    return f"<blockquote>{paragraphs}</blockquote>"


def _render_power_toughness(power, toughness):
    """Render the P/T box, or nothing when either value is missing or "None"."""
    if not power or not toughness or "None" in (power, toughness):
        return ""
    return (
        '<header class="powerToughness"><div><h2 style="font-family: \'Beleren\';font-size: 19px;">'
        f'{_html_escape(power)}/{_html_escape(toughness)}</h2></div></header>'
    )


def format_html(text, image_data):
    """Fill ``./card_template.html`` with the card described by *text*.

    Args:
        text: Card description made of ``Header: value`` lines (Name, ManaCost,
            Type, Rarity, Text, FlavorText, Power, Toughness, Color).
        image_data: Base64-encoded PNG as ``bytes``/``bytearray``, embedded as
            a ``data:`` URI. Any other value is inserted verbatim.

    Returns:
        The rendered HTML. A copy is also written to ``scratch.html``.

    Missing fields render empty instead of raising IndexError. Missing fields
    with a fallback are: ManaCost -> hidden cost, Rarity -> "common", and
    Power/Toughness -> no P/T box. Card text is HTML-escaped because it comes
    from the language model or the user and is rendered by a browser.
    """
    template = pathlib.Path("./card_template.html").read_text(encoding='utf-8')

    name = _card_field(text, r'Name: (.*)')
    mana_cost = _card_field(text, r'Mana.?Cost: (.*)', default="None")
    card_type = _card_field(text, r'Type: (.*)')
    rarity = _card_field(text, r'Rarity: (.*)', default="common")

    rules_match = _RULES_TEXT_RE.search(text)
    card_text = rules_match.group(1).rstrip() if rules_match else ""
    flavor_match = _FLAVOR_TEXT_RE.search(text)
    flavor_text = flavor_match.group(1).rstrip() if flavor_match else ""
    text_lines = _render_rules_lines(card_text)

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'

    # Long rules/flavour text switches to the smaller font size.
    long_text = len(card_text) + len(flavor_text) > 170 or len(text_lines) > 3

    values = {
        "card_color": _card_color_style(text),
        "name": _html_escape(name),
        "mana_cost": _format_mana_cost(mana_cost),
        "image_data": image_src,
        "card_type": _html_escape(card_type),
        "type_size": "16" if len(card_type) > 30 else "18",
        "rarity": f"ss-{_html_escape(rarity, quote=True)}",
        "card_text": "\n".join(text_lines),
        "flavor_text": _render_flavor_text(flavor_text),
        "text_size": "16" if long_text else "18",
        "power_toughness": _render_power_toughness(
            _card_field(text, r'Power: (.*)'), _card_field(text, r'Toughness: (.*)')),
    }
    # Single pass, so placeholder-like text inside card fields is never re-expanded.
    template = _TEMPLATE_PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)
    if long_text:
        # Shrink the inline rules-text icons to match the smaller font.
        template = template.replace(f'ms-cost" {_TEXT_ICON_STYLE}></i>',
                                    f'ms-cost" {_SMALL_TEXT_ICON_STYLE}></i>')

    pathlib.Path("scratch.html").write_text(template, encoding='utf-8')
    return template
def get_savename(directory, name, extension):
    """Return a file name for *name* that does not yet exist in *directory*.

    Tries ``name.extension`` first, then ``name-1.extension``,
    ``name-2.extension``, and so on. The numeric suffix is always appended to
    the full *name*, so names that already contain ``-`` (e.g. "Fire-Ice")
    are kept intact.

    Args:
        directory: Directory to check for existing files.
        name: Base file name without extension.
        extension: Extension without the leading dot.

    Returns:
        The file name only (not joined with *directory*).
    """
    candidate = f"{name}.{extension}"
    suffix = 1
    while os.path.exists(os.path.join(directory, candidate)):
        candidate = f"{name}-{suffix}.{extension}"
        suffix += 1
    return candidate
def html_to_png(card_name, html):
    """Screenshot the card *html* into ``rendered_cards/`` and return the cropped card.

    Args:
        card_name: Card name, used (sanitised) as the PNG file name.
        html: Complete card HTML to render.

    Returns:
        The cropped card as an RGB PIL image. The cropped image is also saved
        back over the screenshot file.

    Raises:
        FileNotFoundError: if the screenshot failed and no PNG was written.
    """
    rendered_card_dir = 'rendered_cards'
    # Card names come from the language model or the user. Replace characters
    # that are invalid in file names (including path separators) so the PNG
    # always lands directly inside rendered_card_dir.
    safe_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]+', '_', card_name).strip() or 'card'
    save_name = get_savename(rendered_card_dir, safe_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join(rendered_card_dir, save_name)
    try:
        hti = Html2Image(output_path=rendered_card_dir)
        paths = hti.screenshot(html_str=html,
                               css_file=['./css/mtg_custom.css', './css/mana.css',
                                         './css/keyrune.css'],
                               save_as=save_name, size=(450, 600))
        print(paths)
        path = paths[0]
    except Exception as err:
        # Best effort: continue with the expected output path, but report why
        # the screenshot failed instead of silently swallowing it.
        print(f'HTML TO PNG SCREENSHOT FAILED: {err!r}')
    print('OPENING IMAGE FROM FILE')
    # Close the screenshot file before overwriting it below.
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        # (left, upper, right, lower): drop the top 50px and keep a 400x550 card.
        cropped_img = img.crop((0, 50, 400, 600))
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')
def get_initial_card():
    """Return the sample card image shown in the "Rendered Card" panel on page load."""
    sample_card_path = 'SampleCard.png'
    return Image.open(sample_card_path)
def pil_to_base64(image):
    """Encode *image* as PNG and return the base64 bytes (``bytes``, not ``str``)."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
def generate_card(image: str, card_text: str):
    """Render *card_text* with *image* as its art and return the finished card image.

    Args:
        image: File path of the card art. May be ``None`` (e.g. after the
            Clear button), in which case ``placeholder.png`` is used.
        card_text: Card description in ``Header: value`` form.

    Returns:
        The rendered card as an RGB PIL image.
    """
    image_path = image or 'placeholder.png'
    # Close the art file once it has been encoded.
    with Image.open(image_path) as art:
        image_data = pil_to_base64(art)
    html = format_html(card_text, image_data)
    name_match = re.search('Name: (.*)', card_text)
    name = name_match.group(1).strip() if name_match else 'card'
    return html_to_png(name, html)
def transcribe(audio: str) -> (str, str):
    """Run speech-to-text on the recorded clip at *audio*.

    Returns:
        ``(transcript, None)``. The transcript goes to the prompt box, and the
        ``None`` is sent back to the microphone input.
    """
    transcript = whisper_pipe(audio)["text"]
    return transcript, None
starting_text = """Name: Wizards of the Coast | |
ManaCost: {0} | |
Type: Enchantment | |
Rarity: mythic rare | |
Text: At the beginning of your upkeep, reveal the top card of your library. If it's a card named "Magic: The Gathering", put it into your hand. Otherwise, put it into your graveyard. | |
FlavorText: "We are the guardians of the multiverse, and we will protect it at all costs." | |
Color: ['U']""" | |
def generate_image(card_text: str):
    """Generate art for the card described by *card_text* via the remote image Space.

    The prompt is built from the card's ``Type:`` and ``Name:`` lines. If
    either is missing, a generic subject is used instead of raising
    IndexError, so the placeholder fallback below still applies.

    Returns:
        Whatever ``client.predict`` returns (presumably a local path to the
        generated image; TODO confirm), or ``'placeholder.png'`` if the call fails.
    """
    text = card_text or ''
    name_match = re.search('Name: (.*)', text)
    type_match = re.search('Type: (.*)', text)
    name = name_match.group(1).strip() if name_match else ''
    card_type = type_match.group(1).strip() if type_match else 'card'
    subject = f"{card_type} {name}".rstrip()
    prompt = f"fantasy illustration of a {subject}, by Greg Rutkowski"
    print(f'Calling image generation with prompt: {prompt}')
    try:
        result = client.predict(
            prompt,  # str in 'parameter_5' Textbox component
            0.3,  # float (numeric value between 0.0 and 5) in 'Guidance' Slider component
            4,  # float (numeric value between 2 and 10) in 'Steps' Slider component
            # float (numeric value between 0 and 12013012031030) in 'Seed' Slider component
            random.randint(0, 12013012031030),
            api_name="/predict"
        )
    except Exception as e:
        # Best effort: keep the UI usable with placeholder art.
        print(f'Failed to generate image from client: {e}')
        return 'placeholder.png'
    print(result)
    return result
def add_hotkeys(path: str = "hotkeys.js") -> str:
    """Return the JavaScript source that ``demo.load`` injects for keyboard shortcuts.

    Args:
        path: Script to read. Defaults to ``hotkeys.js``, resolved against the
            current working directory.

    The file is read as UTF-8 explicitly, matching the template read in
    ``format_html``, so the result does not depend on the platform locale.
    """
    return pathlib.Path(path).read_text(encoding="utf-8")
# Gradio UI: request inputs, the chat reply, the editable card text/art, and the
# rendered card, plus the event chains that connect them.
# NOTE(review): this file's original indentation was lost; the Row/Column
# nesting below is a reconstruction. Confirm the intended layout.
with gr.Blocks(title='MagicGen') as demo:
    gr.Markdown("# 🎴 MagicGenV2")
    gr.Markdown("## Generate and Edit Magic the Gathering Cards with a Chat Assistant")
    with gr.Row():
        with gr.Column():
            # Request inputs: a voice recording (transcribed by transcribe())
            # or a typed request.
            with gr.Group():
                audio_in = gr.Microphone(label="Record a voice request (click or press ctrl + ` to start/stop)",
                                         type='filepath', elem_classes=["record-btn"])
                prompt_in = gr.Textbox(label="Or type a text request and press Enter", interactive=True,
                                       placeholder="Need an idea? Try one of these:\n- Create a creature card named 'WiFi Elemental'\n- Make it an instant\n- Change the color")
            # Full chat-model reply, collapsed by default.
            with gr.Accordion(label='🤖 Chat Assistant Response', open=False):
                bot_text = gr.TextArea(label='Response', interactive=False)
    with gr.Row():
        with gr.Column():
            # Editable card text and art; both feed generate_card().
            in_text = gr.TextArea(label="Card Text (Shift+Enter to submit)", value=starting_text)
            gen_image_button = gr.Button('🖼️ Generate Card Image')
            in_image = gr.Image(label="Card Image (400px x 550px)", type='filepath', value='placeholder.png')
            render_button = gr.Button('🎴 Render Card', variant="primary")
            gr.ClearButton([audio_in, prompt_in, in_text, in_image])
        with gr.Column():
            out_image = gr.Image(label="Rendered Card", value=get_initial_card())
    # Reusable fn/inputs/outputs bundles for the event listeners below.
    transcribe_params = {'fn': transcribe, 'inputs': [audio_in], 'outputs': [prompt_in, audio_in]}
    generate_text_params = {'fn': generate_text, 'inputs': [in_text, prompt_in],
                            'outputs': [bot_text, in_text, audio_in]}
    generate_image_params = {'fn': generate_image, 'inputs': [in_text], 'outputs': [in_image]}
    generate_card_params = {'fn': generate_card, 'inputs': [in_image, in_text], 'outputs': [out_image]}
    # Voice request: transcribe -> write card text -> generate art -> render.
    audio_in.stop_recording(**transcribe_params).then(**generate_text_params).then(**generate_image_params).then(
        **generate_card_params)
    # Typed request: write card text -> generate art -> render.
    prompt_in.submit(**generate_text_params).then(**generate_image_params).then(**generate_card_params)
    # Shift + Enter to submit text in TextAreas
    in_text.submit(**generate_card_params)
    render_button.click(**generate_card_params)
    gen_image_button.click(**generate_image_params).then(**generate_card_params)
    # Inject the hotkey script (read from hotkeys.js) when the page loads.
    demo.load(None, None, None, js=add_hotkeys())
if __name__ == "__main__":
    # Enable the request queue, then serve the app with the custom favicon.
    queued_app = demo.queue()
    queued_app.launch(favicon_path="favicon-96x96.png")