mlabonne commited on
Commit
fab1822
1 Parent(s): 171cab8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -8,6 +8,8 @@ from matplotlib.patches import Patch
8
  from collections import defaultdict
9
  from networkx.drawing.nx_pydot import graphviz_layout
10
  from io import BytesIO
 
 
11
 
12
  TITLE = """
13
  <div align="center">
@@ -163,13 +165,16 @@ def create_family_tree(start_model):
163
 
164
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
165
 
166
- # Instead of plt.show(), capture the plot as an image in memory
167
  img_buffer = BytesIO()
168
- plt.savefig(img_buffer, format='png')
169
  plt.close()
170
- img_buffer.seek(0) # Rewind the buffer to the beginning
 
 
 
171
 
172
- return img_buffer.getvalue()
173
 
174
  with gr.Blocks() as demo:
175
  gr.Markdown(TITLE)
 
8
  from collections import defaultdict
9
  from networkx.drawing.nx_pydot import graphviz_layout
10
  from io import BytesIO
11
+ from PIL import Image
12
+
13
 
14
  TITLE = """
15
  <div align="center">
 
165
 
166
  plt.title(f"{start_model}'s Family Tree", fontsize=20)
167
 
168
+ # Capture the plot as an image in memory
169
  img_buffer = BytesIO()
170
+ plt.savefig(img_buffer, format='png', bbox_inches='tight')
171
  plt.close()
172
+ img_buffer.seek(0)
173
+
174
+ # Open the image using PIL
175
+ img = Image.open(img_buffer)
176
 
177
+ return img
178
 
179
  with gr.Blocks() as demo:
180
  gr.Markdown(TITLE)