Tonic committed
Commit 2ca0200
1 Parent(s): 8c1d821

Update app.py

Files changed (1)
app.py +136 -14
app.py CHANGED
@@ -14,6 +14,7 @@ tokenizer.pad_token = tokenizer.eos_token
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
+# Function for generating text
 def historical_generation(prompt, max_new_tokens=600):
     prompt = f"### Text ###\n{prompt}"
     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
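Note on the context line in the hunk header: GPT-2 checkpoints define no pad token, so the tokenizer.pad_token = tokenizer.eos_token assignment is what lets the padding=True call above work on batched input. A minimal sketch of the same workaround, using the stock gpt2 checkpoint purely for illustration:

    from transformers import GPT2Tokenizer

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # reuse EOS as PAD, as app.py does
    # Without the line above, padding a batch raises "Asking to pad but the
    # tokenizer does not have a padding token".
    enc = tok(["short", "a much longer prompt"], padding=True, return_tensors="pt")
    # enc["attention_mask"] is 0 over the padded positions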
@@ -37,9 +38,10 @@ def historical_generation(prompt, max_new_tokens=600):
 
     # Decode the generated text
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-    # Remove the prompt from the generated text
-    generated_text = generated_text.replace("### Text ###\n", "").strip()
+
+    # Extract text after "### Correction ###"
+    if "### Correction ###" in generated_text:
+        generated_text = generated_text.split("### Correction ###")[1].strip()
 
     # Tokenize the generated text
     tokens = tokenizer.tokenize(generated_text)
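For context, OCRonos-Vintage is an OCR-correction model, so the new extraction assumes the completion contains a "### Correction ###" section after the echoed prompt; a sketch on a made-up completion:

    generated_text = "### Text ###\nTh qu1ck brwn fox.\n### Correction ###\nThe quick brown fox."
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()
    # generated_text == "The quick brown fox."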
@@ -47,17 +49,46 @@ def historical_generation(prompt, max_new_tokens=600):
     # Create highlighted text output
     highlighted_text = []
     for token in tokens:
-        # Remove special tokens and get the token type
-        clean_token = token.replace("Ġ", "").replace("</w>", "")
+        # Clean token and get token type
+        clean_token = token.replace("Ġ", "")
         token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
-
         highlighted_text.append((clean_token, token_type))
 
     return highlighted_text
 
+# Tokenizer information display
+import os
+os.system('python -m spacy download en_core_web_sm')
+import spacy
+from spacy import displacy
+
+nlp = spacy.load("en_core_web_sm")
+
+def text_analysis(text):
+    doc = nlp(text)
+    html = displacy.render(doc, style="dep", page=True)
+    html = (
+        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
+        + html
+        + "</div>"
+    )
+    pos_count = {
+        "char_count": len(text),
+        "token_count": len(list(doc)),
+    }
+    pos_tokens = [(token.text, token.pos_) for token in doc]
+
+    return pos_tokens, pos_count, html
+
+# Gradio interface for text analysis
+def full_interface(prompt, max_new_tokens):
+    generated_highlight = historical_generation(prompt, max_new_tokens)
+    tokens, pos_count, html = text_analysis(prompt)
+    return generated_highlight, pos_count, html
+
 # Create Gradio interface
 iface = gr.Interface(
-    fn=historical_generation,
+    fn=full_interface,
     inputs=[
         gr.Textbox(
             label="Prompt",
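In the loop above, "Ġ" is the byte-level BPE marker GPT-2 uses for a leading space (the "</w>" handling this commit drops belongs to other BPE vocabularies), and the convert_tokens_to_ids/convert_ids_to_tokens round-trip returns the token string unchanged, so each highlight label is simply the raw BPE token. A rough illustration, with representative token boundaries:

    tokens = tokenizer.tokenize("The quick fox")
    # e.g. ['The', 'Ġquick', 'Ġfox']
    highlighted = [(t.replace("Ġ", ""), t) for t in tokens]
    # [('The', 'The'), ('quick', 'Ġquick'), ('fox', 'Ġfox')]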
@@ -72,15 +103,106 @@ iface = gr.Interface(
             value=600
         )
     ],
-    outputs=gr.HighlightedText(
-        label="Generated Historical Text",
-        combine_adjacent=True,
-        show_legend=True
-    ),
+    outputs=[
+        gr.HighlightedText(
+            label="Generated Historical Text",
+            combine_adjacent=True,
+            show_legend=True
+        ),
+        gr.JSON(label="Tokenizer Info"),
+        gr.HTML(label="Dependency Parse Visualization")
+    ],
     title="Historical Text Generation with OCRonos-Vintage",
-    description="Generate historical-style text using the OCRonos-Vintage model. The output shows token types as highlights.",
+    description="Generate historical-style text using OCRonos-Vintage and analyze the tokenizer output.",
     theme=gr.themes.Base()
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
+
+# import torch
+# from transformers import GPT2LMHeadModel, GPT2Tokenizer
+# import gradio as gr
+
+# Load pre-trained model and tokenizer
+# model_name = "PleIAs/OCRonos-Vintage"
+# model = GPT2LMHeadModel.from_pretrained(model_name)
+# tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+
+# Set the pad token to be the same as the eos token
+# tokenizer.pad_token = tokenizer.eos_token
+
+# Set the device to GPU if available, otherwise use CPU
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# model.to(device)
+
+# def historical_generation(prompt, max_new_tokens=600):
+# prompt = f"### Text ###\n{prompt}"
+# inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
+# input_ids = inputs["input_ids"].to(device)
+# attention_mask = inputs["attention_mask"].to(device)
+
+# # Generate text
+# output = model.generate(
+# input_ids,
+# attention_mask=attention_mask,
+# max_new_tokens=max_new_tokens,
+# pad_token_id=tokenizer.eos_token_id,
+# top_k=50,
+# temperature=0.3,
+# top_p=0.95,
+# do_sample=True,
+# repetition_penalty=1.5,
+# bos_token_id=tokenizer.bos_token_id,
+# eos_token_id=tokenizer.eos_token_id
+# )
+
+# # Decode the generated text
+# generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+# # Remove the prompt from the generated text
+# generated_text = generated_text.replace("### Text ###\n", "").strip()
+
+# # Tokenize the generated text
+# tokens = tokenizer.tokenize(generated_text)
+
+# # Create highlighted text output
+# highlighted_text = []
+# for token in tokens:
+# # Remove special tokens and get the token type
+# clean_token = token.replace("Ġ", "").replace("</w>", "")
+# token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
+
+# highlighted_text.append((clean_token, token_type))
+
+# return highlighted_text
+
+# # Create Gradio interface
+# iface = gr.Interface(
+# fn=historical_generation,
+# inputs=[
+# gr.Textbox(
+# label="Prompt",
+# placeholder="Enter a prompt for historical text generation...",
+# lines=3
+# ),
+# gr.Slider(
+# label="Max New Tokens",
+# minimum=50,
+# maximum=1000,
+# step=50,
+# value=600
+# )
+# ],
+# outputs=gr.HighlightedText(
+# label="Generated Historical Text",
+# combine_adjacent=True,
+# show_legend=True
+# ),
+# title="Historical Text Generation with OCRonos-Vintage",
+# description="Generate historical-style text using the OCRonos-Vintage model. The output shows token types as highlights.",
+# theme=gr.themes.Base()
+# )
+
+# if __name__ == "__main__":
+# iface.launch()
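The dependency-parse HTML introduced above comes from spaCy's displacy renderer; note that app.py shells out to download en_core_web_sm at startup. A minimal standalone version of the same call:

    import spacy
    from spacy import displacy

    nlp = spacy.load("en_core_web_sm")  # assumes the model is installed
    doc = nlp("The old ship sailed at dawn.")
    html = displacy.render(doc, style="dep", page=True)  # a complete HTML page as a string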
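full_interface returns three values, which Gradio maps positionally onto the three components in outputs. A self-contained sketch of the same wiring; the stub function and labels are illustrative, not taken from app.py:

    import gradio as gr

    def analyze(text):
        highlights = [(text, "TOKEN")]     # (span, label) pairs for HighlightedText
        stats = {"char_count": len(text)}  # dict rendered by the JSON component
        html = "<b>" + text + "</b>"       # raw markup for the HTML component
        return highlights, stats, html

    demo = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(),
        outputs=[gr.HighlightedText(), gr.JSON(), gr.HTML()],
    )

    if __name__ == "__main__":
        demo.launch()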