Spaces:
Runtime error
Runtime error
neverix
commited on
Commit
·
a40667a
1
Parent(s):
2b04aed
Fix bug? (?)
Browse files- data_loader.py +2 -2
data_loader.py
CHANGED
@@ -226,7 +226,7 @@ class FileDataset(Dataset):
|
|
226 |
if "labels" in sample:
|
227 |
# return UDP as 4chn XYZV float tensor
|
228 |
sample["labels"] = torch.from_numpy(
|
229 |
-
sample["labels"].transpose((2, 0, 1)))
|
230 |
assert (sample["labels"].dtype == torch.float32)
|
231 |
|
232 |
if "image_np" in sample:
|
@@ -270,4 +270,4 @@ class FileDataset(Dataset):
|
|
270 |
"character_masks": character_masks
|
271 |
})
|
272 |
# do not make fake labels in inference
|
273 |
-
return sample
|
|
|
226 |
if "labels" in sample:
|
227 |
# return UDP as 4chn XYZV float tensor
|
228 |
sample["labels"] = torch.from_numpy(
|
229 |
+
sample["labels"].transpose((2, 0, 1)).astype(np.float32))
|
230 |
assert (sample["labels"].dtype == torch.float32)
|
231 |
|
232 |
if "image_np" in sample:
|
|
|
270 |
"character_masks": character_masks
|
271 |
})
|
272 |
# do not make fake labels in inference
|
273 |
+
return sample
|