File size: 3,064 Bytes
2e605bf 7198503 89e8857 2e605bf 7198503 2e605bf 89e8857 58d00bc 89e8857 2e605bf 7198503 89e8857 7198503 2e605bf 89e8857 2e605bf 7198503 89e8857 7198503 89e8857 2e605bf 7198503 2e605bf 89e8857 7198503 2e605bf 89e8857 2e605bf 7198503 2e605bf 7198503 2e605bf 7198503 89e8857 2e605bf 7198503 2e605bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import functools
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceSeq2SeqGenerator,
    HuggingFaceGenerationAlgorithm,
)
from transformers import AutoTokenizer
# Library-style logging: attach a NullHandler so this module emits nothing
# unless the host application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Instruction prepended to the user's prompt for each task offered by the UI.
# NOTE(review): these strings are presumably the exact prompts the
# Text-chem-T5 checkpoints were trained with (including "smile"); do not
# reword them without confirming against the model cards.
task2prefix = {
    # reactants -> product
    "forward": "Predict the product of the following reaction: ",
    # product -> reaction
    "retrosynthesis": "Predict the reaction that produces the following product: ",
    # experimental procedure text -> action sequence
    "paragraph to actions": "Which actions are described in the following paragraph: ",
    # molecule -> natural-language caption
    "molecular captioning": "Caption the following smile: ",
    # natural-language description -> molecule
    "text-conditional de novo generation": "Write in SMILES the described molecule: ",
}
@functools.lru_cache(maxsize=1)
def _load_tokenizer(name_or_path: str = "t5-small"):
    """Load the tokenizer used for post-processing once and reuse it across requests."""
    return AutoTokenizer.from_pretrained(name_or_path)


def run_inference(
    model_name_or_path: str,
    task: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Run one Text-chem-T5 generation and return the cleaned output text.

    Args:
        model_name_or_path: GT4SD ``algorithm_version`` of the model to load
            (one of the checkpoints listed in the UI).
        task: Key of ``task2prefix``; selects the instruction prepended to
            ``prompt``.
        prompt: User input text; what it should contain depends on ``task``
            (a reaction, a product, a paragraph, a molecule, or a description).
        num_beams: Number of beams forwarded to the generator. Coerced to
            ``int`` because UI sliders / CSV-loaded examples may deliver floats.

    Returns:
        The first generated sample. The echoed ``instruction + prompt`` is
        removed, everything from the first EOS token onward is dropped, pad
        tokens are removed, and surrounding whitespace is stripped.

    Raises:
        KeyError: If ``task`` is not a key of ``task2prefix``.
    """
    instruction = task2prefix[task]
    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=instruction,
        prompt=prompt,
        # Beam search requires an integer beam count; a float such as 10.0
        # would break decoding.
        num_beams=int(num_beams),
    )
    model = HuggingFaceGenerationAlgorithm(config)

    # Only the eos/pad token strings are used here. The tokenizer is identical
    # for every request, so it is loaded once rather than per call.
    # NOTE(review): assumes all Text-chem-T5 checkpoints share t5-small's
    # special tokens — confirm if new models are added.
    tokenizer = _load_tokenizer()

    text = list(model.sample(1))[0]
    # The generator output may echo the full model input; strip it.
    text = text.replace(instruction + prompt, "")
    # Keep only the text before the first end-of-sequence marker.
    text = text.split(tokenizer.eos_token)[0]
    text = text.replace(tokenizer.pad_token, "")
    return text.strip()
if __name__ == "__main__":
    # Checkpoints selectable in the UI; the chosen one is passed to
    # run_inference as ``model_name_or_path``.
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Static UI assets live in a "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Header-less CSV of example inputs. Empty cells become "" instead of NaN
    # so they render as blank fields.
    # NOTE(review): columns presumably follow the order of ``inputs`` below
    # (model, task, prompt, num_beams) — confirm against examples.csv.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Explicit UTF-8 so reading the markdown does not depend on the platform's
    # default locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="Text-chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Radio(
                # Derived from task2prefix so the UI can never offer a task
                # that run_inference has no instruction prefix for.
                choices=list(task2prefix),
                label="Task",
                value="paragraph to actions",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|