mischeiwiller commited on
Commit
9526595
·
verified ·
1 Parent(s): 3b00c47

fix: resolve IndexError in line fitting visualization

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -1,13 +1,10 @@
1
- # Example showing how to fit a 2d line with kornia / pytorch
2
  import matplotlib.pyplot as plt
3
  import torch
4
  import matplotlib
5
  matplotlib.use('Agg')
6
- import matplotlib.pyplot as plt
7
  import gradio as gr
8
  from kornia.geometry.line import ParametrizedLine, fit_line
9
 
10
-
11
  def inference(point1, point2, point3, point4):
12
  std = 1.2 # standard deviation for the points
13
  num_points = 50 # total number of points
@@ -16,7 +13,7 @@ def inference(point1, point2, point3, point4):
16
  p0 = torch.tensor([point1, point2], dtype=torch.float32)
17
  p1 = torch.tensor([point3, point4], dtype=torch.float32)
18
  l1 = ParametrizedLine.through(p0, p1)
19
-
20
  # sample some points and weights
21
  pts, w = [], []
22
  for t in torch.linspace(-10, 10, num_points):
@@ -25,37 +22,41 @@ def inference(point1, point2, point3, point4):
25
  p2 += p2_noise
26
  pts.append(p2)
27
  w.append(1 - p2_noise.mean())
 
28
  pts = torch.stack(pts)
29
  w = torch.stack(w)
 
30
  if len(pts.shape) == 2:
31
- pts = pts[None]
32
  if len(w.shape) == 1:
33
- w = w[None]
34
-
35
  l2 = fit_line(pts, w)
36
-
37
- # project some points along the estimated line
38
- p3 = l2.point_at(-10)
39
- p4 = l2.point_at(10)
40
 
41
- X = torch.stack((p3, p4)).detach().numpy()
42
- X_pts = pts.detach().numpy()
 
 
 
43
 
44
- fig = plt.figure()
45
- plt.plot(X_pts[:, 0], X_pts[:, 1], 'ro')
46
- plt.plot(X[:, 0], X[:, 1])
 
 
47
  return fig
48
 
49
  inputs = [
50
- gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 1"),
51
- gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 2"),
52
- gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 3"),
53
- gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 4"),
54
  ]
 
55
  outputs = gr.Plot()
56
 
57
  examples = [
58
- [0.0, 0.0, 1.0, 1.0],
59
  ]
60
 
61
  title = 'Line Fitting'
@@ -69,6 +70,6 @@ demo = gr.Interface(
69
  theme='huggingface',
70
  live=True,
71
  examples=examples,
72
-
73
  )
74
- demo.launch()
 
 
 
1
  import matplotlib.pyplot as plt
2
  import torch
3
  import matplotlib
4
  matplotlib.use('Agg')
 
5
  import gradio as gr
6
  from kornia.geometry.line import ParametrizedLine, fit_line
7
 
 
8
  def inference(point1, point2, point3, point4):
9
  std = 1.2 # standard deviation for the points
10
  num_points = 50 # total number of points
 
13
  p0 = torch.tensor([point1, point2], dtype=torch.float32)
14
  p1 = torch.tensor([point3, point4], dtype=torch.float32)
15
  l1 = ParametrizedLine.through(p0, p1)
16
+
17
  # sample some points and weights
18
  pts, w = [], []
19
  for t in torch.linspace(-10, 10, num_points):
 
22
  p2 += p2_noise
23
  pts.append(p2)
24
  w.append(1 - p2_noise.mean())
25
+
26
  pts = torch.stack(pts)
27
  w = torch.stack(w)
28
+
29
  if len(pts.shape) == 2:
30
+ pts = pts.unsqueeze(0)
31
  if len(w.shape) == 1:
32
+ w = w.unsqueeze(0)
33
+
34
  l2 = fit_line(pts, w)
 
 
 
 
35
 
36
+ # project some points along the estimated line
37
+ p3 = l2.point_at(torch.tensor(-10.0))
38
+ p4 = l2.point_at(torch.tensor(10.0))
39
+ X = torch.stack((p3, p4)).squeeze().detach().numpy()
40
+ X_pts = pts.squeeze().detach().numpy()
41
 
42
+ fig, ax = plt.subplots()
43
+ ax.plot(X_pts[:, 0], X_pts[:, 1], 'ro')
44
+ ax.plot(X[:, 0], X[:, 1])
45
+ ax.set_xlim(X_pts[:, 0].min() - 1, X_pts[:, 0].max() + 1)
46
+ ax.set_ylim(X_pts[:, 1].min() - 1, X_pts[:, 1].max() + 1)
47
  return fig
48
 
49
  inputs = [
50
+ gr.Slider(0.0, 10.0, value=0.0, label="Point 1 X"),
51
+ gr.Slider(0.0, 10.0, value=0.0, label="Point 1 Y"),
52
+ gr.Slider(0.0, 10.0, value=10.0, label="Point 2 X"),
53
+ gr.Slider(0.0, 10.0, value=10.0, label="Point 2 Y"),
54
  ]
55
+
56
  outputs = gr.Plot()
57
 
58
  examples = [
59
+ [0.0, 0.0, 10.0, 10.0],
60
  ]
61
 
62
  title = 'Line Fitting'
 
70
  theme='huggingface',
71
  live=True,
72
  examples=examples,
 
73
  )
74
+
75
+ demo.launch()