3v324v23 commited on
Commit
459e9e9
1 Parent(s): 56fc405
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -46,6 +46,11 @@ RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000,
46
 
47
  obj_dict = {}
48
 
 
 
 
 
 
49
  def trans(x, y, z, length):
50
  w = h = length
51
  x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256
@@ -105,6 +110,7 @@ def objs_to_canvas(lst, length=256, scale = 2.6):
105
  def predict_local_view(lst):
106
  canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
107
  bevs = canvas[..., 0: 0+256]
 
108
  gen = G(code, RT, bevs)
109
  rgb = gen['gen_output']['image'][0] * .5 + .5
110
  return to_pil(rgb)
 
46
 
47
  obj_dict = {}
48
 
49
+ # init
50
+ fake_bevs = torch.zeros([1, 14, 256, 256], device='cuda').float()
51
+ _ = G(code, RT, fake_bevs)
52
+
53
+
54
  def trans(x, y, z, length):
55
  w = h = length
56
  x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256
 
110
  def predict_local_view(lst):
111
  canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
112
  bevs = canvas[..., 0: 0+256]
113
+ print(code.shape, RT.shape, bevs.shape)
114
  gen = G(code, RT, bevs)
115
  rgb = gen['gen_output']['image'][0] * .5 + .5
116
  return to_pil(rgb)