Spaces:
Runtime error
Runtime error
File size: 10,735 Bytes
e126020 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import math
import os
import subprocess
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torchvision
# Fetch the PIXEL source tree so that the `pixel.src.pixel` import below
# resolves. The clone is skipped when a checkout already exists (the original
# unconditional `git clone` failed on every restart and the error was silently
# ignored). If git does fail, `check=True` raises CalledProcessError here
# instead of surfacing later as a confusing ModuleNotFoundError on import.
# Argument list + no shell: nothing is interpreted by /bin/sh.
if not os.path.isdir("pixel"):
    subprocess.run(
        ["git", "clone", "https://github.com/xplip/pixel.git"],
        check=True,
    )
sys.path.append('./pixel')
from transformers import set_seed
from pixel.src.pixel import (
PIXELConfig,
PIXELForPreTraining,
SpanMaskingGenerator,
PyGameTextRenderer,
get_transforms,
resize_model_embeddings,
truncate_decoder_pos_embeddings,
get_attention_mask
)
# Hub checkpoint used for the config, the model weights and the text renderer.
model_name_or_path = "Team-PIXEL/pixel-base"
# Sequence length (in patches) handed to the renderer here and, in
# inference(), to the embedding resize/truncation helpers.
max_seq_length = 529

config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)
text_renderer = PyGameTextRenderer.from_pretrained(
    model_name_or_path,
    max_seq_length=max_seq_length,
)
def clip(x: torch.Tensor) -> torch.Tensor:
    """Multiply ``x`` by 255 and clamp the result to the range [0, 255].

    The operation is purely element-wise. The former ``chw -> hwc -> chw``
    einsum round trip around it was a no-op that also rejected anything but
    3-D input; without it, tensors of any shape are accepted and 3-D inputs
    produce exactly the same values as before.

    Args:
        x: Image tensor, presumably with values roughly in [0, 1].

    Returns:
        A new tensor of the same shape as ``x``; ``x`` itself is not modified.
    """
    return torch.clip(x * 255, 0, 255)
def get_image(img: torch.Tensor, do_clip: bool = True):
    """Turn an image tensor into a PIL image for display.

    ``img`` is presumably a single CHW tensor (inference() passes squeezed
    single images). With ``do_clip`` it is first passed through ``clip``.
    In both cases ``make_grid(normalize=True)`` min-max rescales it before it
    is mapped to 8-bit values and laid out as HWC for PIL.
    """
    source = clip(img) if do_clip else img
    grid = torchvision.utils.make_grid(source, normalize=True)
    as_bytes = grid.mul(255).clamp(0, 255).byte()
    hwc_array = as_bytes.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(hwc_array)
def _expand_patch_mask(patch_mask: torch.Tensor) -> torch.Tensor:
    """Broadcast a per-patch mask to pixel resolution using the model's unpatchify.

    Each entry is repeated over the ``pixels_per_patch ** 2 * 3`` values of
    its patch (presumably ``patch_mask`` is (1, num_patches) — TODO confirm).
    Every size-1 dimension of the result is squeezed away.
    """
    values_per_patch = text_renderer.pixels_per_patch ** 2 * 3
    expanded = patch_mask.unsqueeze(-1).repeat(1, 1, values_per_patch)
    return model.unpatchify(expanded).squeeze()


def _render_text(text: str):
    """Render ``text`` and build the model's pixel and attention inputs.

    Returns:
        ``(pixel_values, attention_mask, num_text_patches)``. The first two
        carry an added leading dimension of size one (a batch of one).
    """
    ppp = text_renderer.pixels_per_patch
    transforms = get_transforms(
        do_resize=True,
        size=(ppp, ppp * text_renderer.max_seq_length),
    )
    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
    )
    pixel_values = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
    return pixel_values, attention_mask.unsqueeze(0), encoding.num_text_patches


def _make_patch_mask(num_text_patches: int, mask_ratio: float, max_span_length: int) -> torch.Tensor:
    """Sample a span mask over the renderer's patches (leading dim of size one)."""
    mask_generator = SpanMaskingGenerator(
        num_patches=text_renderer.max_seq_length,
        num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
        max_span_length=max_span_length,
        spacing="span",
    )
    # The +1 is kept from the original code; presumably it lets masking reach
    # one patch past the rendered text (e.g. an end-of-sequence patch) — TODO confirm.
    return torch.tensor(mask_generator(num_text_patches=num_text_patches + 1)).unsqueeze(0)


def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    """Render ``text``, mask spans of patches, and let PIXEL reconstruct them.

    Args:
        text: Text to render into patches.
        mask_ratio: Fraction of the renderer's ``max_seq_length`` patches to
            mask (rounded up).
        max_span_length: Maximum span length passed to ``SpanMaskingGenerator``.
        seed: Seed passed to ``transformers.set_seed`` before rendering/masking.

    Returns:
        Four PIL images: ``[original, masked, masked predictions, reconstruction]``.
    """
    # Gradio sliders may hand back floats (e.g. 42.0). numpy's legacy seeding,
    # which set_seed calls, raises TypeError for a non-integer seed, and the
    # span length is a count. Both sliders use step=1, so int() is lossless.
    seed = int(seed)
    max_span_length = int(max_span_length)

    # NOTE(review): this mutates the module-level config (the model was built
    # with this same object) and the global RNG state on every request.
    config.update({"mask_ratio": mask_ratio})
    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)
    set_seed(seed)

    img, attention_mask, num_text_patches = _render_text(text)
    patch_mask = _make_patch_mask(num_text_patches, mask_ratio, max_span_length)
    inputs = {
        "pixel_values": img.float(),
        "attention_mask": attention_mask,
        "patch_mask": patch_mask,
    }

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()
    mask = _expand_patch_mask(outputs["mask"].detach().cpu())  # 1 is removing, 0 is keeping
    attention_mask = _expand_patch_mask(attention_mask)
    original_img = model.unpatchify(model.patchify(img)).squeeze()

    # Hide only pixels that are both masked and inside the attention mask;
    # the predictions are kept for exactly those same pixels.
    hidden = ((mask == 1) & (attention_mask == 1)).long()
    im_masked = original_img * (1 - hidden)
    masked_predictions = predictions * mask * attention_mask
    reconstruction = im_masked + masked_predictions
    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]
# Rows for gr.Examples, in the same order as inference()'s parameters:
# (text, mask_ratio, max_span_length, seed).
_EXAMPLE_ROWS = (
    (
        "Penguins are designed to be streamlined and hydrodynamic, so having long legs would add extra drag. Having short legs with webbed feet to act like rudders, helps to give them that torpedo-like figure. If we compare bird anatomy with humans, we would see something a bit peculiar. By taking a look at the side-by-side image in Figure 1, you can see how their leg bones compare to ours. What most people mistake for knees are actually the ankles of the birds. This gives the illusion that bird knees bend opposite of ours. The knees are actually tucked up inside the body cavity of the bird! So how does this look inside of a penguin? In the images below, you can see boxes surrounding the penguins’ knees.",
        0.2,
        6,
        42,
    ),
    (
        "Félicette didn’t seem like a typical astronaut. She weighed just five and a half pounds. She’d spent most of her life on the streets of Paris. And Félicette was a cat, one of 14 trained by French scientists for space flight. In 1963, she went where no feline had gone before. Chosen for her calm demeanor and low weight, Félicette was strapped into a rocket in October of that year. She spent 15 minutes on a dizzying flight to the stars before returning safely to earth. Her legacy, however, has been largely forgotten. While other space animals like Laika the dog and Ham the chimp have been celebrated, Félicette became a footnote of history. This is the story of the only cat to go to space.",
        0.25,
        4,
        42,
    ),
    (
        "In many, many ways, fish of the species Brienomyrus brachyistius do not speak at all like Barack Obama. For starters, they communicate not through a spoken language but through electrical pulses booped out by specialized organs found near the tail. Their vocabulary is also quite unpresidentially poor, with each individual capable of producing just one electric wave—a unique but monotonous signal. “It’s even simpler than Morse code,” Bruce Carlson, a biologist at Washington University in St. Louis who studies Brienomyrus fish, told me. In at least one significant way, though, fish of the species Brienomyrus brachyistius do speak a little bit like Barack Obama. When they want to send an important message… They stop, just for a moment. Those gaps tend to occur in very particular patterns, right before fishy phrases and sentences with “high-information content” about property, say, or courtship, Carlson said. Electric fish have, like the former president, mastered the art of the dramatic pause—a rhetorical trick that can help listeners cue in more strongly to what speakers have to say next, Carlson and his colleagues report in a study published today in Current Biology.",
        0.5,
        1,
        42,
    ),
)
# gr.Examples receives a list of lists, one per row.
examples = [list(row) for row in _EXAMPLE_ROWS]

# Hint text for the empty "Text" input box of the UI.
placeholder_text = "Our message is simple. Because we truly believe in our peanut-loving hearts that peanuts make everything better. Peanuts are perfectly powerful because they're packed with nutrition and they bring people together. Our thirst for peanut knowledge is unquenchable, so we’re always sharing snackable news stories about the benefits of peanuts, recent stats, research, etc. Our passion for peanuts is infectious. We root for peanuts as if they were a home run away from winning it all. We care about peanuts and the people who grow them. We give shout-outs to those who lift up and promote peanuts and the peanut story. We’re an authority on peanuts and we're anything but boring."
demo = gr.Blocks(css="#output_image {width: auto; display: block; margin-left: auto; margin-right: auto;} #button {display: block; margin: 0 auto;}")


def _output_panel(title: str):
    """Add a square-cornered box with a bold heading and one PIL output image.

    The component is created inside whatever layout context is active when
    this is called. ``title`` doubles as the (hidden) image label.
    """
    with gr.Box().style(rounded=False):
        gr.Markdown(f"**{title}**")
        return gr.Image(
            type="pil",
            label=title,
            show_label=False,
            elem_id="output_image",
        )


with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown("Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in [Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply input your piece of text or click one of the examples to load them. Read more at the links below.")

    with gr.Row():
        # Left column: text input, masking controls and the run button.
        with gr.Column():
            tb_text = gr.Textbox(lines=1, label="Text", placeholder=placeholder_text)
            sl_ratio = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.25, label="Span masking ratio")
            sl_len = gr.Slider(minimum=1, maximum=6, step=1, value=6, label="Masking max span length")
            sl_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label="Random seed")
            with gr.Box().style(rounded=False):
                btn = gr.Button("Run", variant="primary", elem_id="button")

        # Right column: a 2x2 grid of result images.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    out_original = _output_panel("Original")
                    out_masked_pred = _output_panel("Masked Predictions")
                with gr.Column():
                    out_masked = _output_panel("Masked")
                    out_reconstruction = _output_panel("Reconstruction")

    # Both the examples and the Run button feed the same controls into
    # inference() and fill the same four images, in this order.
    _run_inputs = [tb_text, sl_ratio, sl_len, sl_seed]
    _run_outputs = [out_original, out_masked, out_masked_pred, out_reconstruction]

    with gr.Row():
        with gr.Box().style(rounded=False):
            gr.Markdown("### Examples")
            gr_examples = gr.Examples(
                examples,
                inputs=_run_inputs,
                outputs=_run_outputs,
                fn=inference,
                cache_examples=True,
            )

    gr.HTML("<p style='text-align: center'><a href='https://arxiv.org/abs/2207.06991' target='_blank'><b>Paper</b></a> | <a href='https://github.com/xplip/pixel' target='_blank'><b>Github</b></a></p>")
    gr.HTML("<center><img src='https://visitor-badge.glitch.me/badge?page_id=Team-PIXEL/PIXEL' alt='visitor badge'></center>")
    btn.click(fn=inference, inputs=_run_inputs, outputs=_run_outputs)

demo.launch(debug=True)
|