neggles:
ok try this
c17e2d8
import html
import logging
from pathlib import Path

import gradio as gr
from gradio.themes.utils import colors
from transformers import CLIPTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
gr_logger = logging.getLogger("gradio")
gr_logger.setLevel(logging.INFO)
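
# small gradio app for visualising how the CLIP BPE tokenizer splits a prompt.
# nb: this targets the gradio 3.x API (Row().style() and launch(enable_queue=...)
# no longer exist in gradio 4.x)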
class ClipUtil:
    def __init__(self):
        logger.info("Loading ClipUtil")
        self.theme = gr.themes.Base(
            primary_hue=colors.violet,
            secondary_hue=colors.indigo,
            neutral_hue=colors.slate,
            font=[gr.themes.GoogleFont("Fira Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
            font_mono=[gr.themes.GoogleFont("Fira Code"), "ui-monospace", "Consolas", "monospace"],
        ).set(
            slider_color_dark="*primary_500",
        )
        # optional stylesheet sitting next to this file (e.g. clip_util.css)
        try:
            self.css = Path(__file__).with_suffix(".css").read_text()
        except Exception:
            logger.exception("Failed to load CSS file")
            self.css = ""
        self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # invert the vocab so we can map token id -> token string
        self.vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
        self.blocks = gr.Blocks(
            title="ClipTokenizerUtil", analytics_enabled=False, theme=self.theme, css=self.css
        )
    def tokenize(self, text: str, input_ids: bool = False):
        if input_ids:
            # input is a comma-separated list of token ids rather than raw text
            tokens = [int(x.strip()) for x in text.split(",")]
        else:
            tokens = self.tokenizer(text, return_tensors="np").input_ids.squeeze().tolist()

        code = ""
        ids = []
        current_ids = []
        class_index = 0
        # byte_decoder maps the printable unicode chars used in the BPE vocab
        # back to their raw byte values (same byte<->unicode trick as GPT-2)
        byte_decoder = self.tokenizer.byte_decoder

        def dump(last=False):
            nonlocal code, ids, current_ids
            words = [self.vocab.get(x, "") for x in current_ids]

            def wordscode(ids, word):
                nonlocal class_index
                word_title = html.escape(", ".join([str(x) for x in ids]))
                res = f"""
                <span class='tokenizer-token tokenizer-token-{class_index % 4}' title='{word_title}'>
                {html.escape(word)}
                </span>
                """
                class_index += 1
                return res

            try:
                word = bytearray([byte_decoder[x] for x in "".join(words)]).decode("utf-8")
            except UnicodeDecodeError:
                # multi-byte characters (e.g. emoji) get split across several
                # byte-level tokens, so keep buffering ids until they decode
                if last:
                    word = "❌" * len(current_ids)
                elif len(current_ids) > 4:
                    # give up on the first id, mark it ❌, then retry the rest one by one
                    id = current_ids[0]
                    ids += [id]
                    local_ids = current_ids[1:]
                    code += wordscode([id], "❌")
                    current_ids = []
                    for id in local_ids:
                        current_ids.append(id)
                        dump()
                    return
                else:
                    return
            # word = word.replace("</w>", " ")
            code += wordscode(current_ids, word)
            ids += current_ids
            current_ids = []

        for token in tokens:
            token = int(token)
            current_ids.append(token)
            dump()
        dump(last=True)

        ids_html = f"""
        <p>
        Token count: {len(ids)}
        <br>
        {", ".join([str(x) for x in ids])}
        </p>"""
        return code, ids_html
    def tokenize_ids(self, text: str):
        return self.tokenize(text, input_ids=True)
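    # expects the same comma-separated format the Tokens output tab prints,
    # e.g. "49406, 320, 1125, 49407" (49406/49407 are CLIP's BOS/EOS markers)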
    def create_components(self):
        with self.blocks:
            # title bar
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=12, elem_id="header_col"):
                    self.header_title = gr.Markdown(
                        "## CLIP Tokenizer Util",
                        elem_id="header_title",
                    )
                with gr.Column(scale=1, min_width=90, elem_id="button_col"):
                    with gr.Row(elem_id="button_row"):
                        # nb: not wired to a click handler yet
                        self.reload_btn = gr.Button(
                            label="refresh",
                            elem_id="refresh_btn",
                            type="button",
                            value="🔄",
                            variant="primary",
                        )
            # input tabs
            with gr.Tabs() as in_tabs:
                with gr.Tab(label="Text Input", id="text_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.text_input = gr.Textbox(
                                label="Text Input",
                                elem_id="tokenizer_prompt",
                                show_label=False,
                                lines=8,
                                placeholder="Prompt for tokenization",
                            )
                            self.text_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                value="Go",
                                variant="primary",
                            )
                with gr.Tab(label="Token Input", id="token_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.token_input = gr.Textbox(
                                lines=5,
                                label="Token Input",
                                elem_id="text_input",
                                placeholder="Comma-separated token ids",
                            )
                            self.token_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                type="button",
                                value="Go",
                                variant="primary",
                            )
            # output tabs
            with gr.Tabs():
                with gr.Tab("Text"):
                    tokenized_text = gr.HTML(elem_id="tokenized_text")
                with gr.Tab("Tokens"):
                    tokenized_ids = gr.HTML(elem_id="tokenized_ids")

            self.text_button.click(
                fn=self.tokenize,
                inputs=[self.text_input],
                outputs=[tokenized_text, tokenized_ids],
            )
            self.token_button.click(
                fn=self.tokenize_ids,
                inputs=[self.token_input],
                outputs=[tokenized_text, tokenized_ids],
            )
    def launch(self, **kwargs):
        return self.blocks.launch(
            server_name="0.0.0.0",
            show_error=True,
            enable_queue=True,
            **kwargs,
        )

if __name__ == "__main__":
    clip_util = ClipUtil()
    clip_util.create_components()
    clip_util.launch()
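
fwiw you can sanity-check the tokenizer path without spinning up the UI. rough sketch, assuming you saved the script as clip_util.py (filename is just my assumption):

    from clip_util import ClipUtil

    util = ClipUtil()
    spans, ids_html = util.tokenize("a photo of a cat")
    print(ids_html)  # token count plus the comma-separated input ids

(the first return value, spans, is the per-token <span> markup the Text tab renders)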