Spaces:
Running
Running
prasanna kumar
commited on
Commit
·
e88497a
1
Parent(s):
43be51d
added text analytics and gemma model support
Browse files
app.py
CHANGED
@@ -1,52 +1,98 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer
|
3 |
import ast
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
7 |
|
8 |
# Available models
|
9 |
-
MODELS = ["Meta-Llama-3.1-8B","gemma-2b"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def process_input(input_type, input_value, model_name):
|
12 |
-
|
13 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path+model_name)
|
14 |
|
15 |
if input_type == "Text":
|
16 |
-
|
17 |
-
# Tokenize the text
|
18 |
-
token_ids = tokenizer.encode(input_value,add_special_tokens=True)
|
19 |
-
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
20 |
-
return len(tokens),character_count, tokens, token_ids
|
21 |
-
|
22 |
elif input_type == "Token IDs":
|
23 |
try:
|
24 |
token_ids = ast.literal_eval(input_value)
|
25 |
-
|
26 |
-
text = tokenizer.decode(token_ids)
|
27 |
-
# Create output strings
|
28 |
-
return len(token_ids),len(token_ids), text, input_value,
|
29 |
except ValueError:
|
30 |
-
return "Error", "Invalid input
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
if __name__ == "__main__":
|
51 |
-
iface.
|
52 |
-
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer
|
3 |
import ast
|
4 |
+
from collections import Counter
|
5 |
+
import re
|
6 |
+
import plotly.graph_objs as go
|
7 |
|
8 |
+
model_path = "models/"
|
9 |
|
10 |
# Available models
|
11 |
+
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"]
|
12 |
+
|
13 |
+
def create_vertical_histogram(data, title):
|
14 |
+
labels, values = zip(*data) if data else ([], [])
|
15 |
+
fig = go.Figure(go.Bar(
|
16 |
+
x=labels,
|
17 |
+
y=values
|
18 |
+
))
|
19 |
+
fig.update_layout(
|
20 |
+
title=title,
|
21 |
+
xaxis_title="Item",
|
22 |
+
yaxis_title="Count",
|
23 |
+
height=400,
|
24 |
+
xaxis=dict(tickangle=-45)
|
25 |
+
)
|
26 |
+
return fig
|
27 |
|
28 |
def process_input(input_type, input_value, model_name):
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)
|
|
|
30 |
|
31 |
if input_type == "Text":
|
32 |
+
text = input_value
|
|
|
|
|
|
|
|
|
|
|
33 |
elif input_type == "Token IDs":
|
34 |
try:
|
35 |
token_ids = ast.literal_eval(input_value)
|
36 |
+
text = tokenizer.decode(token_ids)
|
|
|
|
|
|
|
37 |
except ValueError:
|
38 |
+
return "Error", "Invalid input", "", "", "", None, None, None
|
39 |
|
40 |
+
character_count = len(text)
|
41 |
+
word_count = len(text.split())
|
42 |
+
|
43 |
+
token_ids = tokenizer.encode(text, add_special_tokens=True)
|
44 |
+
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
45 |
+
|
46 |
+
space_count = sum(1 for token in tokens if token == '▁')
|
47 |
+
special_char_count = sum(1 for token in tokens if not token.isalnum() and token != '▁')
|
48 |
+
|
49 |
+
words = re.findall(r'\b\w+\b', text.lower())
|
50 |
+
special_chars = re.findall(r'[^\w\s]', text)
|
51 |
+
numbers = re.findall(r'\d+', text)
|
52 |
+
|
53 |
+
most_common_words = Counter(words).most_common(10)
|
54 |
+
most_common_special_chars = Counter(special_chars).most_common(10)
|
55 |
+
most_common_numbers = Counter(numbers).most_common(10)
|
56 |
+
|
57 |
+
words_hist = create_vertical_histogram(most_common_words, "Most Common Words")
|
58 |
+
special_chars_hist = create_vertical_histogram(most_common_special_chars, "Most Common Special Characters")
|
59 |
+
numbers_hist = create_vertical_histogram(most_common_numbers, "Most Common Numbers")
|
60 |
+
|
61 |
+
analysis = f"Token count: {len(tokens)}\n"
|
62 |
+
analysis += f"Character count: {character_count}\n"
|
63 |
+
analysis += f"Word count: {word_count}\n"
|
64 |
+
analysis += f"Space tokens: {space_count}\n"
|
65 |
+
analysis += f"Special character tokens: {special_char_count}\n"
|
66 |
+
analysis += f"Other tokens: {len(tokens) - space_count - special_char_count}"
|
67 |
+
|
68 |
+
return analysis, " ".join(tokens), str(token_ids), words_hist, special_chars_hist, numbers_hist
|
69 |
+
|
70 |
+
with gr.Blocks() as iface:
|
71 |
+
gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
|
72 |
+
gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")
|
73 |
+
|
74 |
+
with gr.Row():
|
75 |
+
input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
|
76 |
+
model_name = gr.Dropdown(choices=MODELS, label="Select Model")
|
77 |
+
|
78 |
+
input_text = gr.Textbox(lines=5, label="Input")
|
79 |
+
|
80 |
+
submit_button = gr.Button("Process")
|
81 |
+
|
82 |
+
analysis_output = gr.Textbox(label="Analysis", lines=6)
|
83 |
+
tokens_output = gr.Textbox(label="Tokens", lines=3)
|
84 |
+
token_ids_output = gr.Textbox(label="Token IDs", lines=2)
|
85 |
+
|
86 |
+
with gr.Row():
|
87 |
+
words_plot = gr.Plot(label="Most Common Words")
|
88 |
+
special_chars_plot = gr.Plot(label="Most Common Special Characters")
|
89 |
+
numbers_plot = gr.Plot(label="Most Common Numbers")
|
90 |
+
|
91 |
+
submit_button.click(
|
92 |
+
process_input,
|
93 |
+
inputs=[input_type, input_text, model_name],
|
94 |
+
outputs=[analysis_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
|
95 |
+
)
|
96 |
|
97 |
if __name__ == "__main__":
|
98 |
+
iface.launch()
|
|