sadickam commited on
Commit
a26bec3
·
verified ·
1 Parent(s): acd6fc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -129,7 +129,7 @@ def predict_single_text(text):
129
  predicted_prob = [round(a_, 3) for a_ in probabilities.cpu().numpy().tolist() if a_ > threshold]
130
 
131
  # Create a dictionary containing the top predicted IEQ labels and their corresponding probabilities
132
- top_prediction = (dict(zip(predicted_labels, predicted_prob)))
133
 
134
  # Create a bar chart showing the likelihood of each IEQ label
135
  # Make dataframe for plotly bar chart
@@ -141,7 +141,7 @@ def predict_single_text(text):
141
  df2['Likelihood'] = n
142
 
143
  # plot graph of predictions
144
- fig = px.bar(df2, x="Likelihood", y="IEQ", orientation="h")
145
 
146
  fig.update_layout(
147
  # barmode='stack',
@@ -160,9 +160,7 @@ def predict_single_text(text):
160
 
161
  return top_prediction, fig
162
 
163
- # Create Gradio interface for single text
164
-
165
-
166
  iface2 = gr.Interface(fn=predict_single_text,
167
  inputs=gr.Textbox(lines=7, label="Paste or type text here"),
168
  outputs=[gr.Label(label="Top Prediction", show_label=True),
@@ -333,22 +331,22 @@ def predict_from_csv(file, column_name, progress=gr.Progress()):
333
  # Create a downloadable CSV file
334
  output_csv = gr.File(value='IEQ_predictions.csv', visible=True)
335
 
336
- # Create a histogram showing the frequency of each IEQ label
337
- fig = px.histogram(df_docs, y="IEQ_predicted")
338
- fig.update_layout(
339
- template='seaborn',
340
- font=dict(family="Arial", size=12, color="black"),
341
- autosize=True,
342
- # width=800,
343
- # height=500,
344
- xaxis_title="IEQ counts",
345
- yaxis_title="Indoor environmental quality (IEQ)",
346
- )
347
- fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
348
- fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
349
- fig.update_annotations(font_size=12)
350
 
351
- return fig, output_csv
352
 
353
 
354
  # Define the input component
@@ -358,8 +356,7 @@ column_name_input = gr.Textbox(label="Enter the column name containing the text
358
  # Create the Gradio interface
359
  iface3 = gr.Interface(fn=predict_from_csv,
360
  inputs=[file_input, column_name_input],
361
- outputs=[gr.Plot(label='Frequency of IEQs', show_label=True),
362
- gr.File(label='Download output CSV', show_label=True)],
363
  title="Multi-text Prediction",
364
  description='**NOTE:** Please enter the column name containing the text to be analyzed.')
365
 
 
129
  predicted_prob = [round(a_, 3) for a_ in probabilities.cpu().numpy().tolist() if a_ > threshold]
130
 
131
  # Create a dictionary containing the top predicted IEQ labels and their corresponding probabilities
132
+ top_prediction = predicted_labels
133
 
134
  # Create a bar chart showing the likelihood of each IEQ label
135
  # Make dataframe for plotly bar chart
 
141
  df2['Likelihood'] = n
142
 
143
  # plot graph of predictions
144
+ fig = px.bar(df2, x="Likelihood", y="IEQ", orientation="v")
145
 
146
  fig.update_layout(
147
  # barmode='stack',
 
160
 
161
  return top_prediction, fig
162
 
163
+ # Create Gradio interface for single text
 
 
164
  iface2 = gr.Interface(fn=predict_single_text,
165
  inputs=gr.Textbox(lines=7, label="Paste or type text here"),
166
  outputs=[gr.Label(label="Top Prediction", show_label=True),
 
331
  # Create a downloadable CSV file
332
  output_csv = gr.File(value='IEQ_predictions.csv', visible=True)
333
 
334
+ # # Create a histogram showing the frequency of each IEQ label
335
+ # fig = px.histogram(df_docs, y="IEQ_predicted")
336
+ # fig.update_layout(
337
+ # template='seaborn',
338
+ # font=dict(family="Arial", size=12, color="black"),
339
+ # autosize=True,
340
+ # # width=800,
341
+ # # height=500,
342
+ # xaxis_title="IEQ counts",
343
+ # yaxis_title="Indoor environmental quality (IEQ)",
344
+ # )
345
+ # fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
346
+ # fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
347
+ # fig.update_annotations(font_size=12)
348
 
349
+ return output_csv, #fig
350
 
351
 
352
  # Define the input component
 
356
  # Create the Gradio interface
357
  iface3 = gr.Interface(fn=predict_from_csv,
358
  inputs=[file_input, column_name_input],
359
+ outputs=[gr.File(label='Download output CSV', show_label=True)],
 
360
  title="Multi-text Prediction",
361
  description='**NOTE:** Please enter the column name containing the text to be analyzed.')
362