graph method changes

#10
by drod75 - opened
Files changed (1) hide show
  1. app.py +20 -40
app.py CHANGED
@@ -14,9 +14,7 @@ from langchain.schema.output_parser import StrOutputParser
14
  from langchain_core.messages import HumanMessage, SystemMessage
15
  from PIL import Image
16
  import json
17
- import matplotlib.pyplot as plt
18
- from matplotlib.colors import LinearSegmentedColormap
19
- import textwrap
20
 
21
  st.set_page_config(
22
  page_title="Food Chain",
@@ -281,45 +279,27 @@ def display_dishes_in_grid(dishes, cols=3):
281
  st.sidebar.write(dish.replace("_", " ").capitalize())
282
 
283
  def display_prediction_graph(class_names, confidences):
284
- #reversing them so graph displays highest predictions at the top
285
- confidences.reverse()
286
- class_names.reverse()
287
-
288
- #display as a graph
289
- norm = plt.Normalize(min(confidences), max(confidences))
290
- cmap = LinearSegmentedColormap.from_list("grey_orange", ["#808080", "#FFA500"]) #color map grey to orange
291
-
292
- fig, ax = plt.subplots(figsize=(12, 6))
293
- bars = ax.barh(class_names, confidences, color=cmap(norm(confidences)))
294
-
295
- fig.patch.set_alpha(0) # Transparent background
296
- ax.set_facecolor('none')
297
-
298
- min_width = 0.07 * ax.get_xlim()[1] # 7% of the x-axis range
299
- # Add labels inside the bars, aligned to the right
300
- for bar in bars:
301
- original_width = bar.get_width()
302
- width = original_width
303
- if width < min_width:
304
- width = min_width
305
- ax.text(width - 0.02, bar.get_y() + bar.get_height()/2, f'{original_width:.1f}%',
306
- va='center', ha='right', color='white', fontweight='bold', fontsize=16)
307
-
308
- ax.set_xticklabels([]) #remove x label
309
-
310
- # Wrapping labels
311
- max_label_width = 10
312
- labels = ax.get_yticklabels()
313
- wrapped_labels = [textwrap.fill(label.get_text(), width=max_label_width) for label in labels] # Wrap the labels if they exceed the max width
314
- ax.set_yticklabels(wrapped_labels, fontsize=16, color='white')
315
 
316
- #no borders
317
- for spine in ax.spines.values():
318
- spine.set_visible(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- ax.set_title(class_names[-1], color='white', fontsize=24, fontweight='bold', ha='left', x=0.5)
321
-
322
- st.pyplot(fig) # Display the plot
323
 
324
  # #Streamlit
325
 
 
14
  from langchain_core.messages import HumanMessage, SystemMessage
15
  from PIL import Image
16
  import json
17
+ import plotly.graph_objects as go
 
 
18
 
19
  st.set_page_config(
20
  page_title="Food Chain",
 
279
  st.sidebar.write(dish.replace("_", " ").capitalize())
280
 
281
  def display_prediction_graph(class_names, confidences):
282
+ values = [round(confidence, 2) for confidence in confidences]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ fig = go.Figure(go.Bar(
285
+ x=values,
286
+ y=class_names,
287
+ orientation='h',
288
+ marker=dict(color='orange'),
289
+ text=values, # Display values on the bars
290
+ textposition='outside' # Position the text outside the bars
291
+ ))
292
+
293
+ # Update layout for better appearance
294
+ fig.update_layout(
295
+ title="Prediction Graph",
296
+ xaxis_title="Prediction Values",
297
+ yaxis_title="Prediction Categories",
298
+ yaxis=dict(autorange="reversed") # Reverse the y-axis to display top categories at the top
299
+ )
300
 
301
+ # Display the chart in Streamlit
302
+ st.plotly_chart(fig)
 
303
 
304
  # #Streamlit
305