|
import logging |
|
import pathlib |
|
import gradio as gr |
|
import pandas as pd |
|
from gt4sd.algorithms.generation.hugging_face import ( |
|
HuggingFaceSeq2SeqGenerator, |
|
HuggingFaceGenerationAlgorithm |
|
) |
|
from transformers import AutoTokenizer |
|
|
|
logger = logging.getLogger(__name__) |
|
logger.addHandler(logging.NullHandler()) |
|
|
|
def run_inference(
    model_name_or_path: str,
    prefix: str,
    prompt: str,
    num_beams: int,
):
    """Generate a single text sample and strip it down to the bare answer.

    Args:
        model_name_or_path: Passed to the generator configuration as its
            ``algorithm_version``.
        prefix: Task-specific prefix forwarded to the generator.
        prompt: Text prompt forwarded to the generator.
        num_beams: Beam count forwarded to the generator.

    Returns:
        The first generated sample with every occurrence of
        ``prefix + prompt`` removed, truncated at the first end-of-sequence
        token, cleared of padding tokens and stripped of surrounding
        whitespace.
    """
    generator_config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=prefix,
        prompt=prompt,
        num_beams=num_beams,
    )
    algorithm = HuggingFaceGenerationAlgorithm(generator_config)

    # The tokenizer is needed only for its special-token strings; it is
    # always the "t5-small" one, independent of the model selected above.
    t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")

    # Request exactly one sample and keep the first element.
    samples = list(algorithm.sample(1))
    raw_output = samples[0]

    # Drop the echoed input, cut at the end-of-sequence marker, then remove
    # padding and surrounding whitespace.
    without_input = raw_output.replace(prefix + prompt, "")
    before_eos = without_input.split(t5_tokenizer.eos_token)[0]
    return before_eos.replace(t5_tokenizer.pad_token, "").strip()
|
|
|
|
|
if __name__ == "__main__":
    # Model identifiers offered in the "Language model" dropdown.
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Static assets (examples, article, description) live in a
    # "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # The examples file has no header row; empty cells become "" rather
    # than NaN.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Decode the markdown explicitly as UTF-8 so the result does not depend
    # on the platform's default locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    # Input widgets are listed in the same order as run_inference's
    # parameters: model, prefix, prompt, num_beams.
    demo = gr.Interface(
        fn=run_inference,
        title="Text-chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Textbox(
                label="Prefix", placeholder="A task-specific prefix", lines=1
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|
|