Niv Sardi committed on
Commit
b1d65c2
1 Parent(s): 1732876

split dataset into train test val

Browse files
Files changed (3) hide show
  1. python/common/defaults.py +1 -0
  2. python/split.py +43 -0
  3. run.sh +2 -0
python/common/defaults.py CHANGED
@@ -26,3 +26,4 @@ AUGMENTED_LABELS_PATH = D('AUGMENTED_LABELS_PATH', f'{AUGMENTED_DATA_PATH}/label
26
  AUGMENTED_IMAGES_PATH = D('AUGMENTED_IMAGES_PATH', f'{AUGMENTED_DATA_PATH}/images')
27
 
28
  MAIN_CSV_PATH = D('MAIN_CSV_PATH', f'{DATA_PATH}/entities.csv')
 
 
26
  AUGMENTED_IMAGES_PATH = D('AUGMENTED_IMAGES_PATH', f'{AUGMENTED_DATA_PATH}/images')
27
 
28
  MAIN_CSV_PATH = D('MAIN_CSV_PATH', f'{DATA_PATH}/entities.csv')
29
+ SPLIT_DATA_PATH = D('SPLIT_DATA_PATH', f'{DATA_PATH}/split')
python/split.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/python
"""Split a YOLO dataset into named partitions (e.g. train/val/test).

Every image found under ``<datapath>/images`` is assigned to at most one
partition and exposed there through a relative symlink in
``<dest>/<partition>/images``; its label file, when one exists, is linked the
same way under ``<dest>/<partition>/labels``. No data is copied.
"""
import os
import math


def image_to_label(i):
    """Return the label file that belongs to image path *i*, or None.

    The label is looked up as ``<stem>.txt`` in the ``labels`` directory that
    sits next to the directory containing the image. Only the last path
    component and the real extension are rewritten, so file or directory
    names that merely contain ``images``/``.png``/``.jpg`` are left intact.
    """
    images_dir, filename = os.path.split(i)
    stem, _ = os.path.splitext(filename)
    l = os.path.join(os.path.dirname(images_dir), 'labels', stem + '.txt')
    if os.path.exists(l):
        return l
    return None


def parse_partitions(specs):
    """Turn ``name:ratio`` strings into a list of ``(name, ratio)`` tuples.

    Raises:
        ValueError: if a spec is not of the form ``name:ratio`` or the ratio
            is not a number between 0 and 1.
    """
    parsed = []
    for spec in specs:
        name, sep, ratio = spec.partition(':')
        if not name or not sep:
            raise ValueError(f"invalid partition '{spec}': expected name:ratio")
        try:
            value = float(ratio)
        except ValueError:
            raise ValueError(
                f"invalid partition '{spec}': ratio is not a number") from None
        if not 0 <= value <= 1:
            raise ValueError(
                f"invalid partition '{spec}': ratio must be between 0 and 1")
        parsed.append((name, value))
    return parsed


def partition_bounds(total, ratios):
    """Compute contiguous ``(start, end)`` slice bounds for each ratio.

    Each partition gets ``floor(total * ratio)`` items (clamped so it never
    runs past *total*) and starts exactly where the previous one ended, so no
    item is skipped between two partitions. Items left over by rounding down
    are not assigned to any partition.
    """
    bounds = []
    start = 0
    for ratio in ratios:
        end = min(start + math.floor(total * ratio), total)
        bounds.append((start, end))
        # end is exclusive when slicing, so the next partition starts on it
        start = end
    return bounds


def _relink(target, link):
    """Create symlink *link* -> *target*, replacing a symlink left by a previous run.

    A path that exists but is not a symlink is never removed; ``os.symlink``
    then raises ``FileExistsError`` instead of clobbering real data.
    """
    if os.path.islink(link):
        os.unlink(link)
    os.symlink(target, link)


def main():
    """Parse the command line and build the partition directories."""
    # Only needed when run as a script; keeping them here leaves the helpers
    # above importable without the project's `common` package.
    import argparse
    from common import defaults, mkdir

    parser = argparse.ArgumentParser(description='splits a yolo dataset between different data partitions')
    # nargs='?' makes the positional optional, otherwise its default is never used
    parser.add_argument('datapath', metavar='datapath', type=str, nargs='?',
                        help='dataset path (must contain images/ and labels/)',
                        default=defaults.SQUARES_DATA_PATH)
    parser.add_argument('--partitions', metavar='partitions', type=str, nargs='+',
                        help='partitions as name:ratio',
                        default=['train:0.8', 'val:0.1', 'test:0.1'])
    parser.add_argument('--dest', metavar='dest', type=str,
                        help='dest path', default=defaults.SPLIT_DATA_PATH)

    args = parser.parse_args()

    try:
        partitions = parse_partitions(args.partitions)
    except ValueError as err:
        parser.error(str(err))

    images_dir = os.path.join(args.datapath, 'images')
    labels_dir = os.path.join(args.datapath, 'labels')

    # os.scandir yields entries in arbitrary order; sort by name so the same
    # dataset always produces the same split.
    images = sorted((d for d in os.scandir(images_dir) if d.is_file()),
                    key=lambda d: d.name)
    if not images:
        parser.error(f'no images found in {images_dir}')

    bounds = partition_bounds(len(images), [ratio for _, ratio in partitions])

    for (d, _), (p, end) in zip(partitions, bounds):
        cpi = os.path.join(args.dest, d, 'images')
        cpl = os.path.join(args.dest, d, 'labels')
        # links are relative to the directory they live in, so the split
        # keeps working if the whole data tree is moved
        rpi = os.path.relpath(images_dir, cpi)
        rpl = os.path.relpath(labels_dir, cpl)

        # NOTE(review): presumably creates the directories if missing — confirm in common/mkdir
        mkdir.make_dirs([cpi, cpl])
        print( f'{d:6s} [ {p:6d}, {end:6d} ] ({end-p:6d}:{(end-p)/len(images):0.2f} )')
        for si in images[p:end]:
            l = image_to_label(si.path)
            _relink(os.path.join(rpi, si.name), os.path.join(cpi, si.name))
            if l:
                nl = os.path.basename(l)
                _relink(os.path.join(rpl, nl), os.path.join(cpl, nl))


if __name__ == '__main__':
    main()
run.sh CHANGED
@@ -15,5 +15,7 @@ echo "✨ augmenting data"
15
  ${PY} ./python/augment.py
16
  echo "🖼 croping augmented data"
17
  ${PY} ./python/crop.py ./data/augmented/images
 
 
18
  echo "🧠 train model"
19
  sh train.sh
 
15
  ${PY} ./python/augment.py
16
  echo "🖼 croping augmented data"
17
  ${PY} ./python/crop.py ./data/augmented/images
18
+ echo "✂ split dataset into train, val and test groups"
19
+ ${PY} ./python/split.py ./data/squares/
20
  echo "🧠 train model"
21
  sh train.sh