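"""Gradio demo: text generation with HuggingFace language models through GT4SD."""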
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceCTRLGenerator,
    HuggingFaceGenerationAlgorithm,
    HuggingFaceGPT2Generator,
    HuggingFaceOpenAIGPTGenerator,
    HuggingFaceTransfoXLGenerator,
    HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

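# Map each registry application name to its GT4SD configuration class.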
MODEL_FN = {
    "HuggingFaceCTRLGenerator": HuggingFaceCTRLGenerator,
    "HuggingFaceGPT2Generator": HuggingFaceGPT2Generator,
    "HuggingFaceTransfoXLGenerator": HuggingFaceTransfoXLGenerator,
    "HuggingFaceOpenAIGPTGenerator": HuggingFaceOpenAIGPTGenerator,
    "HuggingFaceXLMGenerator": HuggingFaceXLMGenerator,
    "HuggingFaceXLNetGenerator": HuggingFaceXLNetGenerator,
}


def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
):
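    """Generate text with the selected HuggingFace language model.

    Args:
        model_type: application and version joined by an underscore,
            e.g. "HuggingFaceGPT2Generator_gpt2".
        prompt: text prompt to continue.
        length: maximal length of the generated text.
        temperature: decoding temperature.
        prefix: optional text prepended to the prompt.
        k: top-k decoding parameter.
        p: top-p (nucleus) decoding parameter.
        repetition_penalty: penalty for repeated tokens.

    Returns:
        The generated text.
    """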
    # Split only on the first underscore; version names may themselves contain one.
    model_name, version = model_type.split("_", 1)

    if model_name not in MODEL_FN:
        raise ValueError(f"Model type {model_name} not supported")
    config = MODEL_FN[model_name](
        algorithm_version=version,
        prompt=prompt,
        length=int(length),  # sliders emit floats; the configuration expects integers
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=int(k),
        p=p,
        prefix=prefix,
    )

    # Wrap the configuration in the generation algorithm and draw a single sample.
    algorithm = HuggingFaceGenerationAlgorithm(config)
    text = list(algorithm.sample(1))[0]

    return text


if __name__ == "__main__":
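    # Collect every registered HuggingFace generator as an "<application>_<version>" string.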
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] + "_" + x["algorithm_version"]
        for x in all_algos
        if "HuggingFace" in x["algorithm_name"]
    ]

    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

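    # Example inputs for the interface: one row per example, no header row.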
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    print("Examples: ", examples.values.tolist())

    # Long-form article and short description rendered on the Gradio page.
    with open(metadata_root.joinpath("article.md"), "r") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r") as f:
        description = f.read()
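    # One input component per run_inference argument, in matching order.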
    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximum length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # step=1 would restrict this slider to 0.5 and 1.0; use a fractional step.
            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Top-p", step=0.05),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )

    demo.launch(debug=True, show_error=True)