Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -107,9 +107,6 @@ def predict_single_text(text):
|
|
107 |
# Calculate the probabilities
|
108 |
probabilities = torch.sigmoid(logits).squeeze()
|
109 |
|
110 |
-
# Define the threshold for prediction
|
111 |
-
threshold = 0.3
|
112 |
-
|
113 |
# Get the predicted labels
|
114 |
predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
|
115 |
|
@@ -162,12 +159,13 @@ def predict_single_text(text):
|
|
162 |
|
163 |
# Create Gradio interface for single text
|
164 |
iface2 = gr.Interface(fn=predict_single_text,
|
165 |
-
inputs=gr.Textbox(lines=7, label="Paste or type text here"),
|
|
|
166 |
outputs=[gr.Label(label="Top Predictions", show_label=True),
|
167 |
gr.Plot(label="Likelihood of all labels", show_label=True)],
|
168 |
title="Single Text Prediction",
|
169 |
-
|
170 |
-
)
|
171 |
|
172 |
|
173 |
# UPLOAD CSV
|
@@ -245,8 +243,8 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
|
|
245 |
# Calculate the probabilities
|
246 |
predictions = torch.sigmoid(logits).squeeze()
|
247 |
|
248 |
-
# Define the threshold for prediction
|
249 |
-
threshold = 0.3
|
250 |
|
251 |
# Get the predicted labels
|
252 |
predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
|
@@ -352,13 +350,17 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
|
|
352 |
# Define the input component
|
353 |
file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
|
354 |
column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
|
|
|
355 |
|
356 |
# Create the Gradio interface
|
357 |
iface3 = gr.Interface(fn=predict_from_csv,
|
358 |
-
inputs=[file_input, column_name_input
|
|
|
359 |
outputs=gr.File(label='Download output CSV', show_label=True),
|
360 |
title="Multi-text Prediction",
|
361 |
-
description='**
|
|
|
|
|
362 |
|
363 |
# Create a tabbed interface
|
364 |
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
|
|
|
107 |
# Calculate the probabilities
|
108 |
probabilities = torch.sigmoid(logits).squeeze()
|
109 |
|
|
|
|
|
|
|
110 |
# Get the predicted labels
|
111 |
predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
|
112 |
|
|
|
159 |
|
160 |
# Create Gradio interface for single text
|
161 |
iface2 = gr.Interface(fn=predict_single_text,
|
162 |
+
inputs=[gr.Textbox(lines=7, label="Paste or type text here"),
|
163 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
|
164 |
outputs=[gr.Label(label="Top Predictions", show_label=True),
|
165 |
gr.Plot(label="Likelihood of all labels", show_label=True)],
|
166 |
title="Single Text Prediction",
|
167 |
+
description="**Threshold value:** The threshold value determines the minimum probability required for a label to be predicted. A higher threshold value will result in fewer labels being predicted, while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3.",
|
168 |
+
article="**Note:** The quality of model predictions may depend on the quality of the information provided.")
|
169 |
|
170 |
|
171 |
# UPLOAD CSV
|
|
|
243 |
# Calculate the probabilities
|
244 |
predictions = torch.sigmoid(logits).squeeze()
|
245 |
|
246 |
+
# # Define the threshold for prediction
|
247 |
+
# threshold = 0.3
|
248 |
|
249 |
# Get the predicted labels
|
250 |
predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
|
|
|
350 |
# Define the input component
|
351 |
file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
|
352 |
column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
|
353 |
+
threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")
|
354 |
|
355 |
# Create the Gradio interface
|
356 |
iface3 = gr.Interface(fn=predict_from_csv,
|
357 |
+
inputs=[file_input, column_name_input,
|
358 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
|
359 |
outputs=gr.File(label='Download output CSV', show_label=True),
|
360 |
title="Multi-text Prediction",
|
361 |
+
description='''**Threshold value:** The threshold value determines the minimum probability required
|
362 |
+
for a label to be predicted. A higher threshold value will result in fewer labels being predicted,
|
363 |
+
while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3''')
|
364 |
|
365 |
# Create a tabbed interface
|
366 |
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
|