drod75 commited on
Commit
f5ac0be
·
verified ·
1 Parent(s): 2d0046e

graph method changes

Browse files
Files changed (1) hide show
  1. app.py +29 -41
app.py CHANGED
@@ -14,9 +14,8 @@ from langchain.schema.output_parser import StrOutputParser
14
  from langchain_core.messages import HumanMessage, SystemMessage
15
  from PIL import Image
16
  import json
17
- import matplotlib.pyplot as plt
18
- from matplotlib.colors import LinearSegmentedColormap
19
- import textwrap
20
 
21
  st.set_page_config(
22
  page_title="Food Chain",
@@ -281,45 +280,34 @@ def display_dishes_in_grid(dishes, cols=3):
281
  st.sidebar.write(dish.replace("_", " ").capitalize())
282
 
283
  def display_prediction_graph(class_names, confidences):
284
- #reversing them so graph displays highest predictions at the top
285
- confidences.reverse()
286
- class_names.reverse()
287
-
288
- #display as a graph
289
- norm = plt.Normalize(min(confidences), max(confidences))
290
- cmap = LinearSegmentedColormap.from_list("grey_orange", ["#808080", "#FFA500"]) #color map grey to orange
291
-
292
- fig, ax = plt.subplots(figsize=(12, 6))
293
- bars = ax.barh(class_names, confidences, color=cmap(norm(confidences)))
294
-
295
- fig.patch.set_alpha(0) # Transparent background
296
- ax.set_facecolor('none')
297
-
298
- min_width = 0.07 * ax.get_xlim()[1] # 7% of the x-axis range
299
- # Add labels inside the bars, aligned to the right
300
- for bar in bars:
301
- original_width = bar.get_width()
302
- width = original_width
303
- if width < min_width:
304
- width = min_width
305
- ax.text(width - 0.02, bar.get_y() + bar.get_height()/2, f'{original_width:.1f}%',
306
- va='center', ha='right', color='white', fontweight='bold', fontsize=16)
307
-
308
- ax.set_xticklabels([]) #remove x label
309
-
310
- # Wrapping labels
311
- max_label_width = 10
312
- labels = ax.get_yticklabels()
313
- wrapped_labels = [textwrap.fill(label.get_text(), width=max_label_width) for label in labels] # Wrap the labels if they exceed the max width
314
- ax.set_yticklabels(wrapped_labels, fontsize=16, color='white')
315
-
316
- #no borders
317
- for spine in ax.spines.values():
318
- spine.set_visible(False)
319
 
320
- ax.set_title(class_names[-1], color='white', fontsize=24, fontweight='bold', ha='left', x=0.5)
321
-
322
- st.pyplot(fig) # Display the plot
323
 
324
  # #Streamlit
325
 
 
14
  from langchain_core.messages import HumanMessage, SystemMessage
15
  from PIL import Image
16
  import json
17
+ import plotly.graph_objects as go
18
+
 
19
 
20
  st.set_page_config(
21
  page_title="Food Chain",
 
280
  st.sidebar.write(dish.replace("_", " ").capitalize())
281
 
282
  def display_prediction_graph(class_names, confidences):
283
+ values = [round(value, 2) for value in confidences]
284
+
285
+ # Create a horizontal bar chart
286
+ fig = go.Figure(go.Bar(
287
+ x=values,
288
+ y=class_names,
289
+ orientation='h',
290
+ marker=dict(color='orange'),
291
+ text=values, # Display values on the bars
292
+ textposition='outside' # Position the text outside the bars
293
+ ))
294
+
295
+ # Update layout for better appearance
296
+ fig.update_layout(
297
+ title="Prediction Graph",
298
+ title_font=dict(color="black"),
299
+ xaxis_title="Probability",
300
+ yaxis_title="Categories",
301
+ margin=dict(l=20, r=20, t=60, b=20),
302
+ xaxis=dict(title_font=dict(color="black"), tickfont=dict(color="black")), # Ensure x-axis labels and title are black
303
+ yaxis=dict(autorange="reversed", title_font=dict(color="black"), tickfont=dict(color="black")), # Ensure y-axis labels and title are black
304
+ plot_bgcolor='rgba(230, 230, 230, 1)', # Classic gray background for the plot area
305
+ paper_bgcolor='rgba(240, 240, 240, 1)', # Lighter gray background for the paper area
306
+ font=dict(color="black") # Set font color to black for better contrast
307
+ )
 
 
 
 
 
 
 
 
 
 
308
 
309
+ # Display the chart in Streamlit
310
+ st.plotly_chart(fig)
 
311
 
312
  # #Streamlit
313