Spaces:
Runtime error
Runtime error
import math
import os
import subprocess
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
# The PIXEL source code is fetched from GitHub at start-up and imported from
# the local checkout.
# - Clone only when the checkout is missing, so re-running the script in the
#   same directory does not trigger a pointless (and failing) second clone.
# - Pass an argument list (no shell) with check=True, so a failed clone stops
#   start-up here instead of surfacing later as a confusing ImportError from
#   `pixel.src.pixel`.
if not os.path.isdir("pixel"):
    subprocess.run(
        ["git", "clone", "https://github.com/xplip/pixel.git", "pixel"],
        check=True,
    )
sys.path.append('./pixel')

from transformers import set_seed

from pixel.src.pixel import (
    PIXELConfig,
    PIXELForPreTraining,
    SpanMaskingGenerator,
    PyGameTextRenderer,
    get_transforms,
    resize_model_embeddings,
    truncate_decoder_pos_embeddings,
    get_attention_mask
)
# Checkpoint on the Hugging Face Hub, and the number of patches each rendered
# example is laid out into.
model_name_or_path = "Team-PIXEL/pixel-base"
max_seq_length = 529

# Model configuration. The same object is handed to the model below, and
# `inference` later updates it in place via `config.update(...)`.
config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)

# Text-to-image renderer loaded from the same checkpoint.
text_renderer = PyGameTextRenderer.from_pretrained(
    model_name_or_path,
    max_seq_length=max_seq_length,
)
def clip(x: torch.Tensor): | |
x = torch.einsum("chw->hwc", x) | |
x = torch.clip(x * 255, 0, 255) | |
x = torch.einsum("hwc->chw", x) | |
return x | |
def get_image(img: torch.Tensor, do_clip: bool = True):
    """Turn an image tensor into a PIL image for display.

    Args:
        img: Image tensor, presumably a single ``(C, H, W)`` image
            (TODO confirm); it is passed to ``torchvision.utils.make_grid``.
        do_clip: When true, first scale ``img`` to [0, 255] with ``clip``.

    Returns:
        A ``PIL.Image.Image`` built from a uint8, channels-last array.
    """
    if not do_clip:
        source = img
    else:
        source = clip(img)
    # normalize=True makes make_grid rescale the image by its own min/max
    # before the uint8 conversion below.
    grid = torchvision.utils.make_grid(source, normalize=True)
    as_bytes = grid.mul(255).clamp(0, 255).byte()
    channels_last = as_bytes.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(channels_last)
def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    """Render ``text``, span-mask it, and visualise PIXEL's reconstruction.

    Args:
        text: Text rendered into patches by the module-level ``text_renderer``.
        mask_ratio: Written into ``config`` and used to request
            ``ceil(mask_ratio * text_renderer.max_seq_length)`` masked patches
            from ``SpanMaskingGenerator``.
        max_span_length: Longest span the masking generator may produce.
        seed: Passed to ``set_seed`` before the span mask is generated.

    Returns:
        Four PIL images, in order:
        - the rendered original;
        - the input with the masked (attended) patches zeroed out;
        - the model's predictions restricted to those patches;
        - the reconstruction (masked input + masked predictions).
    """
    patch_px = text_renderer.pixels_per_patch
    seq_len = text_renderer.max_seq_length
    # Values per patch: patch_px * patch_px pixels times 3 colour channels.
    values_per_patch = patch_px ** 2 * 3

    def spread_over_pixels(per_patch: torch.Tensor) -> torch.Tensor:
        # Repeat each per-patch value over every value of its patch, then let
        # the model fold the patch sequence back into image layout.
        repeated = per_patch.unsqueeze(-1).repeat(1, 1, values_per_patch)
        return model.unpatchify(repeated).squeeze()

    # NOTE(review): the next three calls mutate the shared config/model on every
    # request. They are presumably idempotent for a fixed max_seq_length —
    # confirm, and consider doing them once at start-up.
    config.update({"mask_ratio": mask_ratio})
    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)
    set_seed(seed)

    to_model_input = get_transforms(
        do_resize=True,
        size=(patch_px, patch_px * seq_len),
    )
    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=seq_len
    ).unsqueeze(0)
    pixel_values = to_model_input(Image.fromarray(encoding.pixel_values)).unsqueeze(0)

    span_masker = SpanMaskingGenerator(
        num_patches=seq_len,
        num_masking_patches=math.ceil(mask_ratio * seq_len),
        max_span_length=max_span_length,
        spacing="span",
    )
    # NOTE(review): the +1 presumably accounts for a trailing end-of-text
    # patch — confirm against the renderer.
    patch_mask = torch.tensor(
        span_masker(num_text_patches=encoding.num_text_patches + 1)
    ).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values.float(),
            attention_mask=attention_mask,
            patch_mask=patch_mask,
        )

    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()
    mask_px = spread_over_pixels(outputs["mask"].detach().cpu())  # 1 is removing, 0 is keeping
    attention_px = spread_over_pixels(attention_mask)
    original_img = model.unpatchify(model.patchify(pixel_values)).squeeze()

    # Pixels hidden from the model: masked AND inside the attended text region.
    hidden = torch.bitwise_and(mask_px == 1, attention_px == 1).long()
    im_masked = original_img * (1 - hidden)
    masked_predictions = predictions * mask_px * attention_px
    reconstruction = im_masked + masked_predictions

    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]
# Example texts offered below the demo.
_PENGUIN_TEXT = "Penguins are designed to be streamlined and hydrodynamic, so having long legs would add extra drag. Having short legs with webbed feet to act like rudders, helps to give them that torpedo-like figure. If we compare bird anatomy with humans, we would see something a bit peculiar. By taking a look at the side-by-side image in Figure 1, you can see how their leg bones compare to ours. What most people mistake for knees are actually the ankles of the birds. This gives the illusion that bird knees bend opposite of ours. The knees are actually tucked up inside the body cavity of the bird! So how does this look inside of a penguin? In the images below, you can see boxes surrounding the penguins’ knees."
_FELICETTE_TEXT = "Félicette didn’t seem like a typical astronaut. She weighed just five and a half pounds. She’d spent most of her life on the streets of Paris. And Félicette was a cat, one of 14 trained by French scientists for space flight. In 1963, she went where no feline had gone before. Chosen for her calm demeanor and low weight, Félicette was strapped into a rocket in October of that year. She spent 15 minutes on a dizzying flight to the stars before returning safely to earth. Her legacy, however, has been largely forgotten. While other space animals like Laika the dog and Ham the chimp have been celebrated, Félicette became a footnote of history. This is the story of the only cat to go to space."
_ELECTRIC_FISH_TEXT = "In many, many ways, fish of the species Brienomyrus brachyistius do not speak at all like Barack Obama. For starters, they communicate not through a spoken language but through electrical pulses booped out by specialized organs found near the tail. Their vocabulary is also quite unpresidentially poor, with each individual capable of producing just one electric wave—a unique but monotonous signal. “It’s even simpler than Morse code,” Bruce Carlson, a biologist at Washington University in St. Louis who studies Brienomyrus fish, told me. In at least one significant way, though, fish of the species Brienomyrus brachyistius do speak a little bit like Barack Obama. When they want to send an important message… They stop, just for a moment. Those gaps tend to occur in very particular patterns, right before fishy phrases and sentences with “high-information content” about property, say, or courtship, Carlson said. Electric fish have, like the former president, mastered the art of the dramatic pause—a rhetorical trick that can help listeners cue in more strongly to what speakers have to say next, Carlson and his colleagues report in a study published today in Current Biology."

# One row per example. Each row gives, in order: the text, the span masking
# ratio, the maximum span length and the random seed (the same order as the
# demo's `inputs=[tb_text, sl_ratio, sl_len, sl_seed]`).
examples = [
    [_PENGUIN_TEXT, 0.2, 6, 42],
    [_FELICETTE_TEXT, 0.25, 4, 42],
    [_ELECTRIC_FISH_TEXT, 0.5, 1, 42],
]

# Hint text for the (initially empty) input text box. It is passed as the
# box's placeholder, not as its value.
placeholder_text = "Our message is simple. Because we truly believe in our peanut-loving hearts that peanuts make everything better. Peanuts are perfectly powerful because they're packed with nutrition and they bring people together. Our thirst for peanut knowledge is unquenchable, so we’re always sharing snackable news stories about the benefits of peanuts, recent stats, research, etc. Our passion for peanuts is infectious. We root for peanuts as if they were a home run away from winning it all. We care about peanuts and the people who grow them. We give shout-outs to those who lift up and promote peanuts and the peanut story. We’re an authority on peanuts and we're anything but boring."
def _boxed_output_image(title):
    """Add a square-cornered box holding a bold caption and a PIL output image.

    ``title`` is used both as the caption text and as the image's (hidden)
    label. Must be called while a Gradio layout context is active.
    """
    with gr.Box().style(rounded=False):
        gr.Markdown(f"**{title}**")
        return gr.Image(
            type="pil",
            label=title,
            show_label=False,
            elem_id="output_image",
        )


# `#output_image` centres the result images; `#button` centres the Run button.
_CSS = (
    "#output_image {width: auto; display: block; margin-left: auto; margin-right: auto;} "
    "#button {display: block; margin: 0 auto;}"
)

demo = gr.Blocks(css=_CSS)

with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown("Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in [Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply input your piece of text or click one of the examples to load them. Read more at the links below.")

    with gr.Row():
        # Left column: the text to render, the masking controls and Run.
        with gr.Column():
            tb_text = gr.Textbox(lines=1, label="Text", placeholder=placeholder_text)
            sl_ratio = gr.Slider(
                minimum=0.01, maximum=1.0, step=0.01, value=0.25,
                label="Span masking ratio",
            )
            sl_len = gr.Slider(
                minimum=1, maximum=6, step=1, value=6,
                label="Masking max span length",
            )
            sl_seed = gr.Slider(
                minimum=0, maximum=1000, step=1, value=42,
                label="Random seed",
            )
            with gr.Box().style(rounded=False):
                btn = gr.Button("Run", variant="primary", elem_id="button")

        # Right column: a 2x2 grid with the four images returned by `inference`.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    out_original = _boxed_output_image("Original")
                    out_masked_pred = _boxed_output_image("Masked Predictions")
                with gr.Column():
                    out_masked = _boxed_output_image("Masked")
                    out_reconstruction = _boxed_output_image("Reconstruction")

    # Ordered like `inference(text, mask_ratio, max_span_length, seed)` and like
    # the list of images it returns.
    io_inputs = [tb_text, sl_ratio, sl_len, sl_seed]
    io_outputs = [out_original, out_masked, out_masked_pred, out_reconstruction]

    with gr.Row():
        with gr.Box().style(rounded=False):
            gr.Markdown("### Examples")
            gr_examples = gr.Examples(
                examples,
                inputs=io_inputs,
                outputs=io_outputs,
                fn=inference,
                cache_examples=True,
            )

    gr.HTML("<p style='text-align: center'><a href='https://arxiv.org/abs/2207.06991' target='_blank'><b>Paper</b></a> | <a href='https://github.com/xplip/pixel' target='_blank'><b>Github</b></a></p>")
    gr.HTML("<center><img src='https://visitor-badge.glitch.me/badge?page_id=Team-PIXEL/PIXEL' alt='visitor badge'></center>")

    btn.click(fn=inference, inputs=io_inputs, outputs=io_outputs)
# Start the server only when this file is executed as a script. Importing the
# module (e.g. to reuse `inference` or `demo`) should not launch the app.
if __name__ == "__main__":
    demo.launch(debug=True)