christofid committed
Commit 7198503
1 Parent(s): bf1c57e

Update app.py

Files changed (1)
  1. app.py +25 -56
app.py CHANGED
@@ -3,70 +3,45 @@ import pathlib
 import gradio as gr
 import pandas as pd
 from gt4sd.algorithms.generation.hugging_face import (
-    HuggingFaceCTRLGenerator,
-    HuggingFaceGenerationAlgorithm,
-    HuggingFaceGPT2Generator,
-    HuggingFaceTransfoXLGenerator,
-    HuggingFaceOpenAIGPTGenerator,
-    HuggingFaceXLMGenerator,
-    HuggingFaceXLNetGenerator,
+    HuggingFaceSeq2SeqGenerator,
+    HuggingFaceGenerationAlgorithm
 )
-from gt4sd.algorithms.registry import ApplicationsRegistry
-
+from transformers import AutoTokenizer
 
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
 
-MODEL_FN = {
-    "HuggingFaceCTRLGenerator": HuggingFaceCTRLGenerator,
-    "HuggingFaceGPT2Generator": HuggingFaceGPT2Generator,
-    "HuggingFaceTransfoXLGenerator": HuggingFaceTransfoXLGenerator,
-    "HuggingFaceOpenAIGPTGenerator": HuggingFaceOpenAIGPTGenerator,
-    "HuggingFaceXLMGenerator": HuggingFaceXLMGenerator,
-    "HuggingFaceXLNetGenerator": HuggingFaceXLNetGenerator,
-}
-
-
 def run_inference(
-    model_type: str,
-    prompt: str,
-    length: float,
-    temperature: float,
+    model_name_or_path: str,
     prefix: str,
-    k: float,
-    p: float,
-    repetition_penalty: float,
+    prompt: str,
+    num_beams: int,
 ):
-    model = model_type.split("_")[0]
-    version = model_type.split("_")[1]
 
-    if model not in MODEL_FN.keys():
-        raise ValueError(f"Model type {model} not supported")
-    config = MODEL_FN[model](
-        algorithm_version=version,
-        prompt=prompt,
-        length=length,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        k=k,
-        p=p,
+    config = HuggingFaceSeq2SeqGenerator(
+        algorithm_version=model_name_or_path,
         prefix=prefix,
+        prompt=prompt,
+        num_beams=num_beams
     )
 
     model = HuggingFaceGenerationAlgorithm(config)
+    tokenizer = AutoTokenizer.from_pretrained("t5-small")
+
     text = list(model.sample(1))[0]
 
+    text = text.split(tokenizer.eos_token)[0]
+    text = text.replace(tokenizer.pad_token, "")
+    text = text.strip()
+
     return text
 
 
 if __name__ == "__main__":
 
     # Preparation (retrieve all available algorithms)
-    all_algos = ApplicationsRegistry.list_available()
-    algos = [
-        x["algorithm_application"] + "_" + x["algorithm_version"]
-        for x in list(filter(lambda x: "HuggingFace" in x["algorithm_name"], all_algos))
-    ]
+    models = ["text-chem-t5-small-standard", "text-chem-t5-small-augm",
+              "text-chem-t5-base-standard", "text-chem-t5-base-augm"]
 
     # Load metadata
     metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
@@ -83,28 +58,22 @@ if __name__ == "__main__":
 
     demo = gr.Interface(
         fn=run_inference,
-        title="HuggingFace language models",
+        title="Text-chem-T5 model",
         inputs=[
             gr.Dropdown(
-                algos,
+                models,
                 label="Language model",
-                value="HuggingFaceGPT2Generator_gpt2",
+                value="text-chem-t5-base-augm",
+            ),
+            gr.Textbox(
+                label="Prefix", placeholder="A task-specific prefix", lines=1
            ),
+            gr.Textbox(
+                label="Prefix", placeholder="A task-specific prefix", lines=1
+            ),
             gr.Textbox(
                 label="Text prompt",
                 placeholder="I'm a stochastic parrot.",
                 lines=1,
             ),
-            gr.Slider(minimum=5, maximum=100, value=20, label="Maximal length", step=1),
-            gr.Slider(
-                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
-            ),
-            gr.Textbox(
-                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
-            ),
-            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
-            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=1),
-            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
+            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
         ],
         outputs=gr.Textbox(label="Output"),
         article=article,
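
For reference, a minimal sketch of how the updated run_inference can be exercised outside the Gradio UI. Only run_inference, the four model version names, and the t5-small tokenizer come from this commit; the prefix and prompt strings below are hypothetical illustrations, and the t5-small tokenizer's eos_token ("</s>") and pad_token ("<pad>") are what the new post-processing strips.

# Sketch only: assumes gt4sd and transformers are installed and that
# app.py (as updated in this commit) is importable from the working directory.
from transformers import AutoTokenizer
from app import run_inference

# What the post-processing added in this commit does to a raw T5 decode:
tokenizer = AutoTokenizer.from_pretrained("t5-small")
raw = "<pad> aspirin is an analgesic.</s>"
clean = raw.split(tokenizer.eos_token)[0].replace(tokenizer.pad_token, "").strip()
print(clean)  # -> "aspirin is an analgesic."

# End-to-end call; prefix and prompt are hypothetical examples.
text = run_inference(
    model_name_or_path="text-chem-t5-base-augm",  # one of the four versions above
    prefix="Write in natural language the description of the molecule: ",
    prompt="CC(=O)Oc1ccccc1C(=O)O",  # aspirin SMILES, illustrative only
    num_beams=10,
)
print(text)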