pszemraj's picture
πŸ“ add details on csearch
0508364
raw
history blame
9.37 kB
import argparse
import logging
import time
import gradio as gr
import torch
from transformers import pipeline
from utils import make_mailto_form, postprocess, clear, make_email_link
# Configure the root logger once for the whole app: INFO and above, each line
# formatted as "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# True when torch can see a CUDA device; used to pick the pipeline device
# (GPU 0 vs. CPU) when the model is loaded.
use_gpu = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - complete an email prompt with the module-level `generator` pipeline

    Decoding is configured with transformers' contrastive-search parameters
    (penalty_alpha, top_k). Per transformers' generation-mode selection, contrastive
    search is only used when penalty_alpha > 0 and top_k > 1; otherwise generation
    falls back to greedy decoding.

    Args:
        prompt (str): the (partial) email text to continue
        gen_length (int, optional): maximum number of new tokens to generate beyond
            the prompt (at least 4 new tokens are always requested). Defaults to 64.
        penalty_alpha (float, optional): contrastive-search degeneration penalty.
            Defaults to 0.6.
        top_k (int, optional): candidate pool size for contrastive search. Defaults to 6.
        no_repeat_ngram_size (int, optional): n-grams of this size may not repeat.
            Defaults to 2.
        length_penalty (float, optional): forwarded to the pipeline as-is; transformers
            documents it for beam-based generation, so it may have no effect here.
            Defaults to 1.0.
        abs_max_length (int, optional): prompt length (in tokens) above which a warning
            is logged. The prompt is NOT truncated. Defaults to 512.
        verbose (bool, optional): log the prompt, parameters and raw output.
            Defaults to False.

    Returns:
        tuple: (mailto form built by utils.make_mailto_form, post-processed email text).
        The first element is rendered in a gr.HTML component and the second in a Textbox.
    """
    global generator  # read-only here; assigned by load_emailgen_model / __main__

    # UI components may hand numbers over as floats; generate() expects integer lengths/sizes.
    gen_length = int(gen_length)
    top_k = int(top_k)
    no_repeat_ngram_size = int(no_repeat_ngram_size)

    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        # Log only parameters that actually exist in this function (the old message
        # referenced undefined names and raised NameError).
        logging.info(
            f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )

    # max_length / min_length count the prompt tokens too, so offset both by the
    # prompt length: at most gen_length and at least 4 newly generated tokens.
    result = generator(
        prompt,
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
    )
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
    logging.info(f"Generation time: {rt:.2f}s")

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - (re)load the text-generation pipeline used for email generation

    Replaces the module-level `generator` that generate_text reads. The model is
    placed on GPU 0 when `use_gpu` is True and on CPU otherwise.

    Args:
        model_tag (str): the huggingface model tag to load

    Returns:
        None: the pipeline is stored in the global `generator` and not returned. The
        "Load Model" button is wired with no outputs, so nothing is expected back.
    """
    global generator
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,  # transformers pipeline: -1 = CPU, 0 = first CUDA device
    )
    logging.info(f"Loaded model: {model_tag}")
def get_parser():
    """
    get_parser - build the command-line argument parser for the demo

    Returns:
        argparse.ArgumentParser: parser defining --model/-m, --verbose/-v,
        --penalty_alpha/-a and --top_k/-k (range checks are done by the caller).
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )
    return parser
default_prompt = """
Hello,
Following up on last week's bubblegum shipment, I"""
available_models = [
"postbot/distilgpt2-emailgen-V2",
"postbot/distilgpt2-emailgen",
"postbot/gpt2-medium-emailgen",
]
if __name__ == "__main__":
logging.info("\n\n\nStarting new instance of app.py")
args = get_parser().parse_args()
logging.info(f"received args:\t{args}")
model_tag = args.model
verbose = args.verbose
top_k = args.top_k
alpha = args.penalty_alpha
assert top_k > 0, "top_k must be greater than 0"
assert alpha >= 0.0 and alpha <= 1.0, "penalty_alpha must be between 0 and 1"
logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
generator = pipeline(
"text-generation",
model_tag,
device=0 if use_gpu else -1,
)
demo = gr.Blocks()
logging.info("launching interface...")
with demo:
gr.Markdown("# Auto-Complete Emails - Demo")
gr.Markdown(
"Enter part of an email, and a text-gen model will complete it! See details below. "
)
gr.Markdown("---")
with gr.Column():
gr.Markdown("## Generate Text")
gr.Markdown("Edit the prompt and parameters and press **Generate**!")
prompt_text = gr.Textbox(
lines=4,
label="Email Prompt",
value=default_prompt,
)
with gr.Row():
clear_button = gr.Button(
value="Clear Prompt",
)
num_gen_tokens = gr.Slider(
label="Generation Tokens",
value=32,
maximum=96,
minimum=16,
step=8,
)
generate_button = gr.Button(
value="Generate!",
variant="primary",
)
gr.Markdown("---")
gr.Markdown("### Results")
# put a large HTML placeholder here
generated_email = gr.Textbox(
label="Generated Text",
placeholder="This is where the generated text will appear",
interactive=False,
)
email_mailto_button = gr.HTML(
"<i>a clickable email button will appear here</i>"
)
gr.Markdown("---")
gr.Markdown("## Advanced Options")
gr.Markdown(
"This demo generates text via the new [constrastive search](https://huggingface.co/blog/introducing-csearch). See details on the csearch blog post for the methods' parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
)
with gr.Row():
model_name = gr.Dropdown(
choices=available_models,
label="Choose a model",
value=model_tag,
)
load_model_button = gr.Button(
"Load Model",
variant="secondary",
)
no_repeat_ngram_size = gr.Radio(
choices=[1, 2, 3, 4],
label="no repeat ngram size",
value=3,
)
with gr.Row():
contrastive_top_k = gr.Radio(
choices=[2, 4, 6, 8],
label="Top K",
value=top_k,
)
penalty_alpha = gr.Slider(
label="Penalty Alpha",
value=alpha,
maximum=1.0,
minimum=0.0,
step=0.1,
)
length_penalty = gr.Slider(
minimum=0.5,
maximum=1.0,
label="Length Penalty",
value=1.0,
step=0.1,
)
gr.Markdown("---")
with gr.Column():
gr.Markdown("## About")
gr.Markdown(
"[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
)
gr.Markdown(
"The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
)
gr.Markdown("---")
clear_button.click(
fn=clear,
inputs=[prompt_text],
outputs=[prompt_text],
)
generate_button.click(
fn=generate_text,
inputs=[
prompt_text,
num_gen_tokens,
penalty_alpha,
contrastive_top_k,
no_repeat_ngram_size,
length_penalty,
],
outputs=[email_mailto_button, generated_email],
)
load_model_button.click(
fn=load_emailgen_model,
inputs=[model_name],
outputs=[],
)
demo.launch(
enable_queue=True,
share=True, # for local testing
)