clement-bonnet commited on
Commit
3e506b8
1 Parent(s): 9ef0f25

fix: pil image generation

Browse files
Files changed (7) hide show
  1. app.py +2 -2
  2. imgs/pattern_0.png +0 -0
  3. imgs/pattern_1.png +0 -0
  4. imgs/pattern_2.png +0 -0
  5. inference.py +5 -4
  6. inference_test.py +31 -0
  7. utils.py +22 -0
app.py CHANGED
@@ -75,10 +75,10 @@ with gr.Blocks() as demo:
75
 
76
  # Right column: Generated image
77
  with gr.Column(scale=1):
78
- output_image = gr.Image(label="Generated Image", height=400, width=400)
79
 
80
  # Handle click events
81
- coord_selector.select(process_click, inputs=[image_idx], outputs=output_image)
82
 
83
  # Launch the app
84
  demo.launch()
 
75
 
76
  # Right column: Generated image
77
  with gr.Column(scale=1):
78
+ output_image = gr.Image(label="Generated Image", height=308, width=328)
79
 
80
  # Handle click events
81
+ coord_selector.select(process_click, inputs=[image_idx], outputs=output_image, trigger_mode="multiple")
82
 
83
  # Launch the app
84
  demo.launch()
imgs/pattern_0.png ADDED
imgs/pattern_1.png CHANGED
imgs/pattern_2.png DELETED
Binary file (15.6 kB)
 
inference.py CHANGED
@@ -17,13 +17,13 @@ from huggingface_hub import snapshot_download
17
  # lpn imports
18
  from src.models.lpn import LPN
19
  from src.models.transformer import EncoderTransformer, DecoderTransformer
20
- from src.visualization import display_grid, ax_to_pil
21
 
22
- from utils import patch_target
23
 
24
 
25
  checkpoint_name = "quiet-thunder-789--checkpoint:v0"
26
- BLUE_LOCATION_INPUTS = {1: 13, 2: 9}
27
 
28
  local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
29
  with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
@@ -93,4 +93,5 @@ def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Ima
93
  output_grid = output_grids[0]
94
  _, ax = plt.subplots(1, 1, figsize=(4, 4))
95
  display_grid(ax=ax, grid=output_grid, grid_shape=jnp.array([4, 4]))
96
- return ax_to_pil(ax)
 
 
17
  # lpn imports
18
  from src.models.lpn import LPN
19
  from src.models.transformer import EncoderTransformer, DecoderTransformer
20
+ from src.visualization import display_grid
21
 
22
+ from utils import patch_target, ax_to_pil
23
 
24
 
25
  checkpoint_name = "quiet-thunder-789--checkpoint:v0"
26
+ BLUE_LOCATION_INPUTS = {0: 13, 1: 9}
27
 
28
  local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
29
  with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
 
93
  output_grid = output_grids[0]
94
  _, ax = plt.subplots(1, 1, figsize=(4, 4))
95
  display_grid(ax=ax, grid=output_grid, grid_shape=jnp.array([4, 4]))
96
+ pil_image = ax_to_pil(ax)
97
+ return pil_image
inference_test.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ from PIL import Image
4
+
5
+ from inference import generate_image
6
+
7
+
8
+ class TestGenerateImage(unittest.TestCase):
9
+ def test_generate_image_output_type(self):
10
+ img = generate_image(image_idx=0, x=0.5, y=0.5)
11
+ self.assertIsInstance(img, Image.Image)
12
+
13
+ def test_generate_image_valid_coordinates(self):
14
+ img = generate_image(image_idx=0, x=0.1, y=0.9)
15
+ self.assertIsInstance(img, Image.Image)
16
+
17
+ def test_generate_image_edge_coordinates(self):
18
+ img = generate_image(image_idx=1, x=0.0, y=1.0)
19
+ self.assertIsInstance(img, Image.Image)
20
+
21
+ def test_generate_image_invalid_image_idx(self):
22
+ with self.assertRaises(KeyError):
23
+ generate_image(image_idx=2, x=0.5, y=0.5)
24
+
25
+ def test_generate_image_eps_boundary(self):
26
+ img = generate_image(image_idx=0, x=1e-5, y=1 - 1e-5)
27
+ self.assertIsInstance(img, Image.Image)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ unittest.main()
utils.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import omegaconf
2
 
3
 
@@ -10,3 +14,21 @@ def patch_target(config):
10
  elif isinstance(value, str) and value.startswith("src_v2"):
11
  # Update the value if it matches the old_value
12
  config[key] = value.replace("src_v2", "src")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
  import omegaconf
6
 
7
 
 
14
  elif isinstance(value, str) and value.startswith("src_v2"):
15
  # Update the value if it matches the old_value
16
  config[key] = value.replace("src_v2", "src")
17
+
18
+
19
+ def ax_to_pil(ax):
20
+ fig = ax.figure
21
+ buf = io.BytesIO()
22
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
23
+ buf.seek(0)
24
+
25
+ # Load the image data completely before closing the buffer
26
+ pil_image = Image.open(buf)
27
+ pil_image_copy = pil_image.copy()
28
+
29
+ # Now we can safely close everything
30
+ pil_image.close()
31
+ buf.close()
32
+ plt.close(fig)
33
+
34
+ return pil_image_copy