christofid's picture
Duplicate from GT4SD/hf-transformers
2e605bf
raw
history blame
3.56 kB
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
HuggingFaceCTRLGenerator,
HuggingFaceGenerationAlgorithm,
HuggingFaceGPT2Generator,
HuggingFaceTransfoXLGenerator,
HuggingFaceOpenAIGPTGenerator,
HuggingFaceXLMGenerator,
HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
# Module-level logger. Only a no-op handler is attached here; producing actual
# log output is left to whatever handlers the running application configures.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Lookup table: generator configuration class name -> the class itself.
# ``run_inference`` uses the part of the model identifier before the underscore
# as the key into this table.
MODEL_FN = dict(
    HuggingFaceCTRLGenerator=HuggingFaceCTRLGenerator,
    HuggingFaceGPT2Generator=HuggingFaceGPT2Generator,
    HuggingFaceTransfoXLGenerator=HuggingFaceTransfoXLGenerator,
    HuggingFaceOpenAIGPTGenerator=HuggingFaceOpenAIGPTGenerator,
    HuggingFaceXLMGenerator=HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator=HuggingFaceXLNetGenerator,
)
def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
):
    """Generate one text sample with the selected HuggingFace language model.

    Args:
        model_type: Identifier of the form ``"<application>_<version>"``, for
            example ``"HuggingFaceGPT2Generator_gpt2"``. The application part
            must be a key of ``MODEL_FN``; the version part is forwarded as
            ``algorithm_version``.
        prompt: Forwarded to the generator configuration as ``prompt``.
        length: Forwarded as ``length`` after conversion to ``int``.
        temperature: Forwarded as ``temperature``.
        prefix: Forwarded as ``prefix``.
        k: Forwarded as ``k`` after conversion to ``int``.
        p: Forwarded as ``p``.
        repetition_penalty: Forwarded as ``repetition_penalty``.

    Returns:
        The first item produced by ``HuggingFaceGenerationAlgorithm.sample(1)``.

    Raises:
        ValueError: If ``model_type`` lacks the ``"_"`` separator, has an empty
            application or version part, or names an application that is not
            in ``MODEL_FN``.
    """
    # Split on the FIRST underscore only, so a version string that itself
    # contains underscores is passed through intact instead of being truncated.
    application, separator, version = (model_type or "").partition("_")
    if not separator or not application or not version:
        raise ValueError(
            f"Model type {model_type!r} must have the form "
            "'<application>_<version>'"
        )

    config_cls = MODEL_FN.get(application)
    if config_cls is None:
        raise ValueError(f"Model type {application} not supported")

    config = config_cls(
        algorithm_version=version,
        prompt=prompt,
        # The UI may deliver these as floats even though they are counts;
        # presumably the generator configs expect integers -- TODO confirm.
        length=int(length),
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=int(k),
        p=p,
        prefix=prefix,
    )
    algorithm = HuggingFaceGenerationAlgorithm(config)
    # Only a single sample is requested; return it directly.
    text = list(algorithm.sample(1))[0]
    return text
if __name__ == "__main__":
    # Preparation (retrieve all available algorithms) and keep only the
    # HuggingFace ones, encoded as "<application>_<version>" -- the format
    # that run_inference splits apart again.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_application"] + "_" + entry["algorithm_version"]
        for entry in all_algos
        if "HuggingFace" in entry["algorithm_name"]
    ]

    # Load metadata (example inputs and the markdown shown around the demo).
    metadata_root = pathlib.Path(__file__).parent / "model_cards"
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Read the markdown explicitly as UTF-8 so the result does not depend on
    # the platform's default encoding.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        # The order of the inputs must match run_inference's positional
        # parameters: model_type, prompt, length, temperature, prefix, k, p,
        # repetition_penalty.
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximal length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # A step of 1 on a 0.5-1 range made every intermediate value
            # unselectable (and the default 1.0 fell off the step grid), so use
            # a fractional step instead.
            gr.Slider(
                minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=0.01
            ),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)