"""Gradio demo comparing TorchServe endpoints with and without BetterTransformer.

The UI sends identical requests to a "vanilla" endpoint and a BetterTransformer
endpoint (addresses come from ``defaults``) and shows both responses side by side.
"""

import json
import math

import datasets
import gradio as gr
import torch

from backend import get_message_single, get_message_spam, send_single, send_spam, tokenizer
from defaults import (
    ADDRESS_BETTERTRANSFORMER,
    ADDRESS_VANILLA,
    defaults_bt_single,
    defaults_bt_spam,
    defaults_vanilla_single,
    defaults_vanilla_spam,
    BATCH_SIZE,
)

# Token ids written into artificial requests: start marker, end marker, padding.
# NOTE(review): 101/102/0 are presumably [CLS]/[SEP]/[PAD] of a BERT-style
# vocabulary — confirm they match `tokenizer`.
_CLS_TOKEN_ID = 101
_SEP_TOKEN_ID = 102
_PAD_TOKEN_ID = 0


def dispatch_single(input_model_single, address_input_vanilla, address_input_bettertransformer):
    """Send one text to both endpoints and return their results.

    Args:
        input_model_single: Text to send.
        address_input_vanilla: Address of the vanilla endpoint.
        address_input_bettertransformer: Address of the BetterTransformer endpoint.

    Returns:
        ``(result_vanilla, result_bettertransformer)``, each as returned by ``send_single``.
    """
    result_vanilla = send_single(input_model_single, address_input_vanilla)
    result_bettertransformer = send_single(input_model_single, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer


def dispatch_spam(
    input_n_spam,
    address_input_vanilla,
    address_input_bettertransformer,
    dataset=None,
):
    """Send a random sample of dataset rows to both endpoints.

    Args:
        input_n_spam: Number of rows to sample (converted with ``int``).
        address_input_vanilla: Address of the vanilla endpoint.
        address_input_bettertransformer: Address of the BetterTransformer endpoint.
        dataset: Object supporting ``len``, ``.shuffle()`` and ``.select(...)``
            (presumably a ``datasets.Dataset``). Defaults to a module-level
            ``data`` object, if one is defined.

    Returns:
        ``(result_vanilla, result_bettertransformer)``, each as returned by ``send_spam``.

    Raises:
        gr.Error: if no dataset is available, or more rows are requested than it has.
    """
    if dataset is None:
        # NOTE(review): this file never defines a module-level `data`, so the
        # previous body raised NameError. Looking it up keeps older setups that
        # inject `data` working, while failing with a clear message otherwise.
        dataset = globals().get("data")
    if dataset is None:
        raise gr.Error("No dataset available: pass `dataset` or define a module-level `data`.")

    input_n_spam = int(input_n_spam)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if input_n_spam > len(dataset):
        raise gr.Error(
            f"Requested {input_n_spam} inputs but the dataset only has {len(dataset)} rows."
        )

    inp = dataset.shuffle().select(range(input_n_spam))
    result_vanilla = send_spam(inp, address_input_vanilla)
    result_bettertransformer = send_spam(inp, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer


def _build_artificial_request(sequence_length, padding_ratio, vocab_size):
    """Return a JSON request body holding one random, right-padded token sequence.

    Layout of ``input_ids``:
    - ``_CLS_TOKEN_ID`` first;
    - then random ids in ``[1, vocab_size - 1]``;
    - then ``_SEP_TOKEN_ID``;
    - then ``n_pads`` copies of ``_PAD_TOKEN_ID``.

    ``attention_mask`` is 1 on every non-pad position and 0 on the padding.

    ``n_pads`` is ``int(padding_ratio * sequence_length)``, clamped to the range
    ``[1, sequence_length - 2]`` so both markers always fit.

    Raises:
        gr.Error: if ``sequence_length`` is below 3 (start, end and one pad).
    """
    if sequence_length < 3:
        raise gr.Error(
            "Sequence length must be at least 3 (start token, end token and one pad)."
        )

    # randint(high, size) draws from [0, high); +1 shifts to [1, vocab_size - 1].
    inp_tokens = torch.randint(vocab_size - 1, (sequence_length,)) + 1

    # At least one pad, because `[:-0]` below would be an empty slice.
    # At most sequence_length - 2, so the end marker cannot overwrite the start
    # marker at index 0 or fall off the front of the tensor.
    n_pads = min(max(int(padding_ratio * sequence_length), 1), sequence_length - 2)
    inp_tokens[-n_pads:] = _PAD_TOKEN_ID
    inp_tokens[0] = _CLS_TOKEN_ID
    inp_tokens[-n_pads - 1] = _SEP_TOKEN_ID

    attention_mask = torch.zeros((sequence_length,), dtype=torch.int64)
    attention_mask[:-n_pads] = 1

    return json.dumps({
        "input_ids": inp_tokens.cpu().tolist(),
        "attention_mask": attention_mask.cpu().tolist(),
        "pre_tokenized": True,
    })


def dispatch_spam_artif(
    input_n_spam_artif,
    sequence_length,
    padding_ratio,
    address_input_vanilla,
    address_input_bettertransformer,
):
    """Send ``input_n_spam_artif`` copies of one artificial sequence to both endpoints.

    Args:
        input_n_spam_artif: Number of requests (converted with ``int``).
        sequence_length: Sequence length in tokens (converted with ``int``).
        padding_ratio: Fraction of ``sequence_length`` filled with padding.
        address_input_vanilla: Address of the vanilla endpoint.
        address_input_bettertransformer: Address of the BetterTransformer endpoint.

    Returns:
        ``(result_vanilla, result_bettertransformer)``, each as returned by ``send_spam``.

    Raises:
        gr.Error: if ``sequence_length`` is below 3.
    """
    sequence_length = int(sequence_length)
    input_n_spam_artif = int(input_n_spam_artif)

    str_input = _build_artificial_request(sequence_length, padding_ratio, tokenizer.vocab_size)
    # Every request carries the same pre-tokenized sequence; only the count varies.
    input_dataset = datasets.Dataset.from_dict({"sentence": [str_input] * input_n_spam_artif})

    result_vanilla = send_spam(input_dataset, address_input_vanilla)
    result_bettertransformer = send_spam(input_dataset, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer


TTILE_IMAGE = """
"""

TITLE = """

Speed up your inference and support more workload with PyTorch's BetterTransformer 🤗

"""

# Kept flush-left so markdown renders the same whether or not Gradio dedents it.
DESCRIPTION = """
Let's try out TorchServe + BetterTransformer!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) allowing to use a fastpath execution for encoder attention blocks.

As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the [🤗 Optimum](https://huggingface.co/docs/optimum/main/en/index) library:

```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployement of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what are the benefits server-side and client-side of using BetterTransformer.

## Inference using...
"""

with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")

    # Hidden textboxes: they carry the endpoint addresses into the event handlers.
    address_input_vanilla = gr.Textbox(
        max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
    )
    address_input_bettertransformer = gr.Textbox(
        max_lines=1,
        label="ip bettertransformer",
        value=ADDRESS_BETTERTRANSFORMER,
        visible=False,
    )

    # --- Single request ---------------------------------------------------
    input_model_single = gr.Textbox(
        max_lines=1,
        label="Text",
        value="Expectations were low, enjoyment was high",
    )
    btn_single = gr.Button("Send single text request")

    with gr.Row():
        with gr.Column(scale=50):
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
        with gr.Column(scale=50):
            output_single_bt = gr.Markdown(
                label="Output single bt",
                value=get_message_single(**defaults_bt_single),
            )

    btn_single.click(
        fn=dispatch_single,
        inputs=[input_model_single, address_input_vanilla, address_input_bettertransformer],
        outputs=[output_single_vanilla, output_single_bt],
    )

    # --- Many requests with artificial, pre-tokenized data -----------------
    input_n_spam_artif = gr.Number(
        label="Number of inputs to send",
        value=8,
    )
    sequence_length = gr.Number(
        label="Sequence length (in tokens)",
        value=128,
    )
    padding_ratio = gr.Number(
        label="Padding ratio",
        value=0.5,
    )
    btn_spam_artif = gr.Button("Spam text requests (using artificial data)")

    with gr.Row():
        with gr.Column(scale=50):
            output_spam_vanilla_artif = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
        with gr.Column(scale=50):
            output_spam_bt_artif = gr.Markdown(
                label="Output spam bt",
                value=get_message_spam(**defaults_bt_spam),
            )

    btn_spam_artif.click(
        fn=dispatch_spam_artif,
        inputs=[
            input_n_spam_artif,
            sequence_length,
            padding_ratio,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_spam_vanilla_artif, output_spam_bt_artif],
    )

# Process queued events one at a time (presumably so the two endpoints are not
# benchmarked under overlapping load). NOTE(review): `concurrency_count` is a
# Gradio 3.x argument; Gradio 4 replaced it with `default_concurrency_limit`.
demo.queue(concurrency_count=1)
demo.launch()