File size: 7,251 Bytes
cd851c8
 
 
 
 
 
 
 
f593a0f
 
cd851c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f39ba
cd851c8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
from typing import Any, Callable, List, Optional, Tuple

import nltk
nltk.download('punkt')
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

print(gr.__version__)

# A folderpath for where the examples are stored
# (relative to the working directory the script is launched from)
EXAMPLES_FOLDER_NAME = "examples"

# A List of repo names for the huggingface models available for inference.
# Each entry is prefixed "huggingface/" because gr.Interface.load() uses that
# prefix to route the load to the Hugging Face inference backend.
HF_MODELS = ["huggingface/facebook/bart-large-cnn", 
        "huggingface/sshleifer/distilbart-xsum-12-6", 
        "huggingface/google/pegasus-xsum", 
        "huggingface/philschmid/bart-large-cnn-samsum", 
        "huggingface/linydub/bart-large-samsum",
        "huggingface/philschmid/distilbart-cnn-12-6-samsum",
        "huggingface/knkarthick/MEETING-SUMMARY-BART-LARGE-XSUM-SAMSUM-DIALOGSUM-AMI",
]


################################################################################
# Functions: Document statistics
################################################################################
# Function that uses a huggingface tokenizer to count how many tokens are in a text
def count_tokens(input_text: str, model_path: str = 'sshleifer/distilbart-cnn-12-6') -> int:
    """Count how many tokens *input_text* produces under a huggingface tokenizer.

    Args:
        input_text: The raw text to tokenize.
        model_path: HF hub repo id whose tokenizer defines the token count.

    Returns:
        Number of token ids (includes any special tokens the tokenizer adds).
    """
    # Loading a tokenizer from the hub is expensive (disk/network); cache one
    # instance per model_path on the function object so repeated calls are cheap.
    cache = count_tokens.__dict__.setdefault('_tokenizer_cache', {})
    tokenizer = cache.get(model_path)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        cache[model_path] = tokenizer
    # Tokenize and count the resulting input ids
    tokens = tokenizer(input_text)
    return len(tokens['input_ids'])

# Sentence count helper backed by nltk's Punkt sentence tokenizer
def count_sentences(input_text):
    """Return the number of sentences nltk detects in *input_text*."""
    sentences = nltk.sent_tokenize(input_text)
    sentence_total = len(sentences)
    return sentence_total

# Word count helper backed by nltk's word tokenizer
def count_words(input_text):
    """Return the number of word tokens nltk detects in *input_text*."""
    tokens = nltk.word_tokenize(input_text)
    word_total = len(tokens)
    return word_total

# Function that computes a few document statistics such as the number of tokens, sentences, and words
def compute_stats(input_text, models: Optional[List[str]] = None, max_tokens: int = 1024):
    """Build a human-readable statistics string for *input_text*.

    Args:
        input_text: The document to analyze.
        models: Unused; kept for interface compatibility with the gradio caller.
        max_tokens: Advertised model context limit shown in the output.
            Defaults to 1024, the limit of the BART-family models used here.

    Returns:
        A multi-line string with token, sentence, and word counts.
    """
    num_tokens = count_tokens(input_text)
    num_sentences = count_sentences(input_text)
    num_words = count_words(input_text)
    # Format the document statistics as a single display string
    output_str = f"| Tokens: {num_tokens} \n| Sentences: {num_sentences} \n| Words: {num_words}\n"
    output_str += f"The max number of tokens for the model is: {max_tokens}\n"
    return output_str

# # A function to loop through a list of strings
# # returning the last element in the filepath for each string
# def get_file_names(file_paths):
#     # Create a list of file names
#     file_names = []
#     # Loop through the file paths
#     for file_path in file_paths:
#         # Get the last element in the file path
#         file_name = file_path.split('/')[-2:]
#         # Add the file name to the list
#         file_names.append(file_name)
#     # Loop through the file names and append to a string
#     file_names_str = ""
#     for file_name in file_names:
#         breakpoint()
#         file_names_str += file_name[0] + "\n"
#     # Return the list of file names
#     return file_names_str

################################################################################
# Functions: Huggingface Inference
################################################################################

# Function that uses a huggingface pipeline to predict a summary of a text
# input is a text string of a dialog conversation
def predict(dialog_text):
    """Summarize *dialog_text* with a samsum-finetuned BART model.

    Args:
        dialog_text: A dialog conversation as a single string.

    Returns:
        Tuple of (labelled summary string, the placeholder string "output2").
    """
    # Name the model once so the pipeline and the output label stay in sync.
    # (Previously the label referenced an undefined `hf_model_name` -> NameError.)
    hf_model_name = "philschmid/bart-large-cnn-samsum"
    # Load a huggingface summarization pipeline
    model = pipeline('summarization', model=hf_model_name)
    # Truncate inputs to the model's 1024-token context on inference
    tokenizer_kwargs = {'truncation': True, 'max_length': 1024}
    # Use the model to predict a summary of the text
    summary = model(dialog_text, **tokenizer_kwargs)
    # Return the summary w/ the model name
    output = f"{hf_model_name} output: {summary[0]['summary_text']}"
    return output, "output2"

def recursive_predict(dialog_text: str, hf_model_name: Tuple[str]):
    """Planned: recursively summarize a dialog that exceeds the model context.

    Not implemented yet. The previous body was debug scaffolding (a leftover
    ``breakpoint()`` and a reference to an undefined ``output`` variable that
    raised NameError); raise explicitly instead so the gradio interface shows
    a clear error rather than dropping into a debugger.

    Args:
        dialog_text: The dialog conversation to summarize.
        hf_model_name: Model repo name(s) to use for the recursive passes.

    Raises:
        NotImplementedError: Always, until the recursive strategy is written.
    """
    raise NotImplementedError("recursive_predict is not implemented yet")

################################################################################
# Functions: Gradio Utilities
################################################################################
# Function to build examples for gradio app
# Load text files from the examples folder as a list of strings for gradio
def get_examples(folder_path):
    """Load every file in *folder_path* as a gradio example entry.

    Args:
        folder_path: Directory containing plain-text example files.

    Returns:
        A list of ``[file_contents, ["None"]]`` pairs, one per file,
        in sorted filename order.
    """
    examples = []
    # sorted() makes the example order deterministic across platforms
    # (os.listdir order is filesystem-dependent)
    for file in sorted(os.listdir(folder_path)):
        # Explicit encoding: default is platform-dependent and can
        # mis-decode UTF-8 example files on Windows
        with open(os.path.join(folder_path, file), 'r', encoding='utf-8') as f:
            examples.append([f.read(), ["None"]])
    return examples

# A function that loops through a list of model paths, creates a gradio interface with the 
# model name, and adds it to the list of interfaces
# It outputs a list of interfaces
def get_hf_interfaces(models_to_load):
    """Build one gradio interface per huggingface model path.

    Args:
        models_to_load: Iterable of "huggingface/<org>/<model>" repo paths.

    Returns:
        List of loaded gr.Interface objects, in input order.
    """
    return [
        gr.Interface.load(model_path, title="this is a test TITLE", alias="this is an ALIAS")
        for model_path in models_to_load
    ]

################################################################################
# Build Gradio app
################################################################################
# print_details = gr.Interface(
#     fn=lambda x: get_file_names(HF_MODELS),
#     inputs="text",
#     outputs="text",
#     title="Statistics of the document"
# )
# Outputs a string of various document statistics
document_statistics = gr.Interface(
    fn=compute_stats,
    inputs="text",
    outputs="text",
    title="Statistics of the document"
)
# Custom recursive-summarization interface (currently not added to the app;
# see the commented-out append below).
# NOTE(review): title duplicates document_statistics' — probably a copy-paste
# leftover; confirm intended title before enabling.
maddie_mixer_summarization = gr.Interface(
    fn=recursive_predict,
    inputs="text",
    outputs="text",
    title="Statistics of the document"
)

# Build Examples to pass along to the gradio app
examples = get_examples(EXAMPLES_FOLDER_NAME)

# Build a list of huggingface interfaces from model paths, 
# then add document statistics, and any custom interfaces
all_interfaces = get_hf_interfaces(HF_MODELS)
all_interfaces.insert(0, document_statistics) # Insert the statistics interface at the beginning
# all_interfaces.insert(0, print_details)
# all_interfaces.append(maddie_mixer_summarization) # Add the interface for the maddie mixer

# Build app: gr.Parallel runs every interface on the same input and shows
# all outputs side by side.
# NOTE(review): gr.inputs.Textbox is the legacy (pre-3.x) gradio namespace —
# confirm the installed gradio version supports it (the version is printed at
# startup above).
app = gr.Parallel(*all_interfaces,
            title='Text Summarizer (Maddie Custom)',
            description="Write a summary of a text",
            # examples=examples,
            inputs=gr.inputs.Textbox(lines = 10, label="Text"),
            )

# Launch: open a browser tab and surface exceptions in the UI
app.launch(inbrowser=True, show_error=True)