Spaces:
Running
Running
import argparse | |
import pprint as pp | |
import logging | |
import time | |
import gradio as gr | |
import torch | |
from transformers import pipeline | |
from utils import make_mailto_form, postprocess, clear, make_email_link | |
# Root logging configuration: INFO and above, with timestamp and level prefix.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Probed once at import time; pipelines are built on device 0 when True, else -1 (CPU).
use_gpu: bool = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - generate text using the text generation pipeline

    Uses the module-level ``generator`` pipeline (set by load_emailgen_model or
    the __main__ block), then postprocesses the output into an email.

    :param str prompt: the prompt to use for the text generation pipeline
    :param int gen_length: the number of tokens to generate beyond the prompt
    :param float penalty_alpha: the penalty alpha for the text generation pipeline (contrastive search)
    :param int top_k: the top k for the text generation pipeline (contrastive search)
    :param float length_penalty: length penalty forwarded to the pipeline call
    :param int abs_max_length: prompt length (in tokens) above which a warning is logged;
        generation is still attempted
    :param bool verbose: log the prompt, parameters, output and generation time
    :return tuple: (output of make_mailto_form for the postprocessed email,
        the postprocessed email text) - the app routes these to an HTML
        component and a textbox respectively
    """
    global generator
    # Gradio numeric components may deliver floats (e.g. 40.0); the token
    # budget below must be an integer.
    gen_length = int(gen_length)
    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, "
            f"top_k={top_k}, length_penalty={length_penalty}"
        )
    st = time.perf_counter()
    input_len = len(generator.tokenizer(prompt)["input_ids"])
    if input_len > abs_max_length:
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )
    result = generator(
        prompt,
        # old API for generation: max/min_length count the prompt tokens too,
        # so offset both by the prompt length
        max_length=gen_length + input_len,
        min_length=input_len + 4,  # require at least 4 tokens beyond the prompt
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        length_penalty=length_penalty,
    )
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
        logging.info(f"Generation time: {rt:.2f}s")
    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text generation pipeline for email generation

    The new pipeline replaces the module-level ``generator`` global that
    generate_text reads. Nothing is returned, so this can be wired to a
    Gradio button that has no outputs.

    Args:
        model_tag (str): the huggingface model tag to load. If empty or None
            (e.g. a cleared dropdown), the currently loaded model is kept.
    Returns:
        None
    """
    global generator
    if not model_tag:
        # transformers.pipeline substitutes the task's default model when
        # model is None (and fails on ""), so never forward a falsy tag.
        logging.warning("No model tag given, keeping the currently loaded model")
        return None
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,  # first GPU if available, else CPU
    )
    logging.info(f"Loaded model: {model_tag}, use GPU = {use_gpu}")
    return None
def get_parser():
    """
    get_parser - build the command-line argument parser for the demo

    Returns:
        argparse.ArgumentParser: parser with --model, --max_length,
        --penalty_alpha, --top_k and --verbose options
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-l",
        "--max_length",
        required=False,
        type=int,
        default=40,
        help="default number of tokens to generate (initial 'Generation Tokens' value) - default 40",
    )
    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser
# Initial contents of the "Email Prompt" textbox.
default_prompt = (
    "\n"
    "Hello,\n"
    "Following up on last week's bubblegum shipment, I"
)
# Hugging Face model tags offered in the "Choose a model" dropdown.
available_models = [
    f"postbot/{model_name}"
    for model_name in (
        "distilgpt2-emailgen-V2",
        "distilgpt2-emailgen",
        "gpt2-medium-emailgen",
        "pythia-160m-hq-emails",
    )
]
if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    parser = get_parser()
    args = parser.parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    max_length = args.max_length
    top_k = args.top_k
    alpha = args.penalty_alpha
    # Validate with parser.error (usage message + exit code 2) rather than
    # `assert`, which is stripped when Python runs with -O.
    if not model_tag:
        parser.error("model tag must not be empty")
    if max_length <= 0:
        parser.error("max_length must be greater than 0")
    if top_k <= 0:
        parser.error("top_k must be greater than 0")
    if not 0.0 <= alpha <= 1.0:
        parser.error("penalty_alpha must be between 0 and 1")

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Same loader the "Load Model" button uses; it sets the global `generator`.
    load_emailgen_model(model_tag)

    # Keep CLI-provided values selectable in the UI widgets, which otherwise
    # only offer the hard-coded choices/ranges below.
    model_choices = (
        available_models
        if model_tag in available_models
        else [model_tag, *available_models]
    )
    top_k_choices = sorted({2, 4, 6, 8, top_k})

    def _generate_text_from_ui(
        prompt, gen_length, alpha_value, top_k_value, length_penalty_value
    ):
        """Forward the UI inputs to generate_text, applying the --verbose CLI flag."""
        return generate_text(
            prompt,
            gen_length=gen_length,
            penalty_alpha=alpha_value,
            top_k=top_k_value,
            length_penalty=length_penalty_value,
            verbose=verbose,
        )

    demo = gr.Blocks()
    logging.info("launching interface...")
    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )
            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=max_length,
                    maximum=max(96, max_length),
                    minimum=min(16, max_length),
                    step=8,
                )
            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("---")
            gr.Markdown("### Results")
            # read-only textbox that receives the postprocessed email text
            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            email_mailto_button = gr.HTML(
                "<i>a clickable email button will appear here</i>"
            )
            gr.Markdown("---")
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via the new [contrastive search](https://huggingface.co/blog/introducing-csearch). See the csearch blog post for details on the parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=model_choices,
                    label="Choose a model",
                    value=model_tag,
                )
                load_model_button = gr.Button(
                    "Load Model",
                    variant="secondary",
                )
            with gr.Row():
                contrastive_top_k = gr.Radio(
                    choices=top_k_choices,
                    label="Top K",
                    value=top_k,
                )
                penalty_alpha = gr.Slider(
                    label="Penalty Alpha",
                    value=alpha,
                    maximum=1.0,
                    minimum=0.0,
                    step=0.1,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="Length Penalty",
                    value=1.0,
                    step=0.1,
                )
            gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=_generate_text_from_ui,
            inputs=[
                prompt_text,
                num_gen_tokens,
                penalty_alpha,
                contrastive_top_k,
                length_penalty,
            ],
            outputs=[email_mailto_button, generated_email],
        )
        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )
    # `enable_queue` in launch() is deprecated; enable queuing via Blocks.queue().
    demo.queue().launch(
        share=True,  # for local testing
    )