sadickam commited on
Commit
c3ce081
·
verified ·
1 Parent(s): e587583

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -107,9 +107,6 @@ def predict_single_text(text):
107
  # Calculate the probabilities
108
  probabilities = torch.sigmoid(logits).squeeze()
109
 
110
- # Define the threshold for prediction
111
- threshold = 0.3
112
-
113
  # Get the predicted labels
114
  predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
115
 
@@ -162,12 +159,13 @@ def predict_single_text(text):
162
 
163
  # Create Gradio interface for single text
164
  iface2 = gr.Interface(fn=predict_single_text,
165
- inputs=gr.Textbox(lines=7, label="Paste or type text here"),
 
166
  outputs=[gr.Label(label="Top Predictions", show_label=True),
167
  gr.Plot(label="Likelihood of all labels", show_label=True)],
168
  title="Single Text Prediction",
169
- article="**Note:** The quality of model predictions may depend on the quality of information provided."
170
- )
171
 
172
 
173
  # UPLOAD CSV
@@ -245,8 +243,8 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
245
  # Calculate the probabilities
246
  predictions = torch.sigmoid(logits).squeeze()
247
 
248
- # Define the threshold for prediction
249
- threshold = 0.3
250
 
251
  # Get the predicted labels
252
  predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
@@ -352,13 +350,17 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
352
  # Define the input component
353
  file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
354
  column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
 
355
 
356
  # Create the Gradio interface
357
  iface3 = gr.Interface(fn=predict_from_csv,
358
- inputs=[file_input, column_name_input],
 
359
  outputs=gr.File(label='Download output CSV', show_label=True),
360
  title="Multi-text Prediction",
361
- description='**NOTE:** Please enter the column name containing the text to be analyzed.')
 
 
362
 
363
  # Create a tabbed interface
364
  demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
 
107
  # Calculate the probabilities
108
  probabilities = torch.sigmoid(logits).squeeze()
109
 
 
 
 
110
  # Get the predicted labels
111
  predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
112
 
 
159
 
160
  # Create Gradio interface for single text
161
  iface2 = gr.Interface(fn=predict_single_text,
162
+ inputs=[gr.Textbox(lines=7, label="Paste or type text here"),
163
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
164
  outputs=[gr.Label(label="Top Predictions", show_label=True),
165
  gr.Plot(label="Likelihood of all labels", show_label=True)],
166
  title="Single Text Prediction",
167
+ description="**Threshold value:** The threshold value determines the minimum probability required for a label to be predicted. A higher threshold value will result in fewer labels being predicted, while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3.",
168
+ article="**Note:** The quality of model predictions may depend on the quality of the information provided.")
169
 
170
 
171
  # UPLOAD CSV
 
243
  # Calculate the probabilities
244
  predictions = torch.sigmoid(logits).squeeze()
245
 
246
+ # # Define the threshold for prediction
247
+ # threshold = 0.3
248
 
249
  # Get the predicted labels
250
  predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
 
350
  # Define the input component
351
  file_input = gr.File(label="Upload CSV or Excel file here", show_label=True, file_types=[".csv", ".xls", ".xlsx"])
352
  column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
353
+ threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")
354
 
355
  # Create the Gradio interface
356
  iface3 = gr.Interface(fn=predict_from_csv,
357
+ inputs=[file_input, column_name_input,
358
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Threshold value (default=0.3)")],
359
  outputs=gr.File(label='Download output CSV', show_label=True),
360
  title="Multi-text Prediction",
361
+ description='''**Threshold value:** The threshold value determines the minimum probability required
362
+ for a label to be predicted. A higher threshold value will result in fewer labels being predicted,
363
+ while a lower threshold value will result in more labels being predicted. The default threshold value is 0.3''')
364
 
365
  # Create a tabbed interface
366
  demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],