Spaces:
Build error
Build error
import os | |
import os.path as op | |
import json | |
# import logging | |
import base64 | |
import yaml | |
import errno | |
import io | |
import math | |
from PIL import Image, ImageDraw | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from .box_label_loader import LabelLoader | |
def load_linelist_file(linelist_file):
    """Read a file holding one integer line index per line.

    Args:
        linelist_file: Path to the line-list file, or None.

    Returns:
        A list of ints in file order, or None when ``linelist_file`` is None.
    """
    if linelist_file is None:
        return None
    with open(linelist_file, 'r') as handle:
        return [int(raw.strip()) for raw in handle]
def img_from_base64(imagestring):
    """Decode a base64-encoded image string into an RGB PIL image.

    Args:
        imagestring: Base64 text (or bytes) of an encoded image file.

    Returns:
        The decoded image converted to mode 'RGB', or None when decoding or
        opening raises ValueError. Other exception types propagate.
    """
    try:
        raw_bytes = base64.b64decode(imagestring)
        decoded = Image.open(io.BytesIO(raw_bytes))
        return decoded.convert('RGB')
    except ValueError:
        return None
def load_from_yaml_file(yaml_file):
    """Parse a YAML file and return the resulting Python object.

    Args:
        yaml_file: Path to the YAML file.

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    # ``yaml.CLoader`` only exists when PyYAML was built against libyaml;
    # fall back to the pure-Python loader so the import-time configuration
    # of PyYAML does not turn into an AttributeError here.
    loader_cls = getattr(yaml, 'CLoader', None) or yaml.Loader
    with open(yaml_file, 'r') as fp:
        # NOTE(security): CLoader/Loader are the *full* loaders and can
        # construct arbitrary Python objects. Only use on trusted files;
        # switch to a safe loader if configs may come from untrusted sources.
        return yaml.load(fp, Loader=loader_cls)
def find_file_path_in_yaml(fname, root):
    """Resolve a file name either as given or relative to ``root``.

    Args:
        fname: File name or path; may be None.
        root: Directory tried as a prefix when ``fname`` itself is not a file.

    Returns:
        None when ``fname`` is None; ``fname`` when it is an existing file;
        otherwise ``op.join(root, fname)`` when that is an existing file.

    Raises:
        FileNotFoundError: When neither candidate is an existing file. The
            error's filename is the ``root``-joined path.
    """
    if fname is None:
        return None
    if op.isfile(fname):
        return fname
    candidate = op.join(root, fname)
    if op.isfile(candidate):
        return candidate
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), candidate)
def create_lineidx(filein, idxout):
    """Write a line-offset index for ``filein`` to ``idxout``.

    The index holds one decimal number per line of ``filein``: the byte
    offset at which that line starts. It is written to ``idxout + '.tmp'``
    first and renamed into place once complete, so a partially written index
    is never visible under the final name.

    Args:
        filein: Path of the file to index.
        idxout: Path of the index file to create.
    """
    idxout_tmp = idxout + '.tmp'
    # Read in binary mode: text-mode ``tell()`` returns an opaque cookie that
    # is not guaranteed to be a byte offset, and looping until that cookie
    # equals the file size could spin forever. Summing raw line lengths gives
    # exact byte offsets and always terminates at EOF.
    with open(filein, 'rb') as tsvin, open(idxout_tmp, 'w') as tsvout:
        offset = 0
        for raw_line in tsvin:
            tsvout.write(str(offset) + "\n")
            offset += len(raw_line)
    os.rename(idxout_tmp, idxout)
def read_to_character(fp, c):
    """Read from ``fp`` up to (not including) the first occurrence of ``c``.

    The stream is consumed in 32-character chunks, so on return the stream
    position is at the end of the chunk containing ``c``, not right after it.

    Args:
        fp: Text stream positioned where reading should start.
        c: Delimiter to search for. It is looked up within each chunk
            independently, so a multi-character delimiter that straddles a
            chunk boundary is not detected.

    Returns:
        The text read before the delimiter.

    Raises:
        AssertionError: If the stream ends before ``c`` is found.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        if chunk == '':
            # Raised explicitly (rather than via ``assert``) so the check
            # survives ``python -O``; without it the loop would never end at
            # EOF. The exception type is kept for backward compatibility.
            raise AssertionError(
                'reached end of stream before finding {!r}'.format(c)
            )
        cut = chunk.find(c)
        if cut >= 0:
            pieces.append(chunk[:cut])
            break
        pieces.append(chunk)
    return ''.join(pieces)
class TSVFile(object):
    """Random-access reader for a TSV file driven by a ``.lineidx`` file.

    The index file sits next to the TSV (same path, extension replaced by
    ``.lineidx``) and holds one integer per row: the position passed to
    ``seek`` on the TSV handle to reach the start of that row. Both the TSV
    handle and the index are loaded lazily on first use.
    """

    def __init__(self, tsv_file, generate_lineidx=False):
        """Remember paths; optionally build the index when it is missing.

        Args:
            tsv_file: Path to the TSV file.
            generate_lineidx: When True and the index file does not exist,
                create it with ``create_lineidx``.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # Pid of the process that opened ``_fp``. When the current pid
        # differs, the file is re-opened (see ``_ensure_tsv_opened``).
        self.pid = None
        if generate_lineidx and not op.isfile(self.lineidx):
            create_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Return the number of entries in the line index."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row ``idx`` as a list of whitespace-stripped columns."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        offset = self._lineidx[idx]
        self._fp.seek(offset)
        columns = self._fp.readline().split('\t')
        return [column.strip() for column in columns]

    def seek_first_column(self, idx):
        """Return only the text before the first tab of row ``idx``."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        self._fp.seek(self._lineidx[idx])
        return read_to_character(self._fp, '\t')

    def get_key(self, idx):
        """Return the first column of row ``idx``."""
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        """Load the index file into memory on first use."""
        if self._lineidx is not None:
            return
        with open(self.lineidx, 'r') as fp:
            self._lineidx = [int(raw.strip()) for raw in fp]

    def _ensure_tsv_opened(self):
        """Open the TSV if not yet open, or if it was opened by another pid."""
        current_pid = os.getpid()
        if self._fp is None or self.pid != current_pid:
            self._fp = open(self.tsv_file, 'r')
            self.pid = current_pid
class CompositeTSVFile():
    """Present several TSV files as one sequence of rows.

    ``seq_file`` has one line per composite row holding two tab-separated
    integers: an index into ``file_list`` and a row index within that file.
    """

    def __init__(self, file_list, seq_file, root='.'):
        """
        Args:
            file_list: Either a list of TSV file names, or the path of a
                text file listing them (read with ``load_list_file``).
            seq_file: Path of the sequence file described on the class.
            root: Directory the TSV file names are joined onto.
        """
        if not isinstance(file_list, str):
            assert isinstance(file_list, list)
            self.file_list = file_list
        else:
            self.file_list = load_list_file(file_list)
        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        """Return ``<source file name>_<row key>`` for composite row ``index``."""
        source_idx, row_idx = self.seq[index]
        row_key = self.tsvs[source_idx].get_key(row_idx)
        return '_'.join((self.file_list[source_idx], row_key))

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        source_idx, row_idx = self.seq[index]
        return self.tsvs[source_idx].seek(row_idx)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        '''
        Load the sequence file and create one TSVFile per source file.

        This has to be called in the init function if cache_policy is
        enabled, so it is always called there to keep things simple.
        Calling it again after a successful run is a no-op.
        '''
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, 'r') as fp:
            for row in fp:
                fields = row.strip().split('\t')
                self.seq.append([int(fields[0]), int(fields[1])])
        self.tsvs = [TSVFile(op.join(self.root, name)) for name in self.file_list]
        self.initialized = True
def load_list_file(fname):
    """Read a text file into a list of whitespace-stripped lines.

    If the final entry is empty after stripping, that single entry is
    dropped; other empty entries are kept.

    Args:
        fname: Path of the file to read.

    Returns:
        List of stripped lines in file order.
    """
    with open(fname, 'r') as fp:
        entries = [raw.strip() for raw in fp]
    if entries and entries[-1] == '':
        entries.pop()
    return entries
class TSVDataset(object):
    """Dataset reading images and optional labels/sizes from TSV files.

    Subclasses are expected to override ``get_target_from_annotations`` and
    ``apply_transforms``; the base versions pass their inputs through.
    """

    def __init__(self, img_file, label_file=None, hw_file=None,
                 linelist_file=None, imageid2idx_file=None):
        """Constructor.
        Args:
            img_file: Image file with image key and base64 encoded image str.
            label_file: An optional label file with image key and label information.
                A label_file is required for training and optional for testing.
            hw_file: An optional file with image key and image height/width info.
            linelist_file: An optional file with a list of line indexes to load samples.
                It is useful to select a subset of samples or duplicate samples.
            imageid2idx_file: An optional JSON file. When given, samples are
                indexed by label rows and the loaded mapping translates each
                label row's "img_id" into a row index of img_file.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = TSVFile(img_file)
        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
        self.line_list = load_linelist_file(linelist_file)
        self.imageid2idx = None
        if imageid2idx_file is not None:
            # Context manager so the handle is closed deterministically.
            with open(imageid2idx_file, 'r') as fp:
                self.imageid2idx = json.load(fp)
        self.transforms = None

    def __len__(self):
        """Number of samples: line list length if set, else row count."""
        if self.line_list is not None:
            return len(self.line_list)
        if self.imageid2idx is not None:
            # With an id->row mapping, samples are enumerated by label rows.
            assert self.label_tsv is not None, "label_tsv is None!!!"
            return self.label_tsv.num_rows()
        return self.img_tsv.num_rows()

    def __getitem__(self, idx):
        """Return ``(img, target, idx, scale)`` for sample ``idx``.

        ``scale`` is 1.0 when no transforms are set; otherwise it is the
        square root of the ratio between transformed and original image area.
        """
        img = self.get_image(idx)
        img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        img, target = self.apply_transforms(img, target)
        if self.transforms is None:
            return img, target, idx, 1.0
        # presumably the transformed image is array-like with shape
        # (channels, height, width) — TODO confirm against the transforms used
        new_img_size = img.shape[1:]
        scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
        return img, target, idx, scale

    def get_line_no(self, idx):
        """Map a sample index to a TSV row index via the optional line list."""
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        """Load and decode the image for sample ``idx``."""
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            # Look up the image row through the label row's "img_id".
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            imageid = annotations["img_id"]
            line_no = self.imageid2idx[imageid]
        row = self.img_tsv.seek(line_no)
        # use -1 to support old format with multiple columns.
        img = img_from_base64(row[-1])
        return img

    def get_annotations(self, idx):
        """Return the JSON-decoded label column, or [] without a label file."""
        if self.label_tsv is None:
            return []
        line_no = self.get_line_no(idx)
        row = self.label_tsv.seek(line_no)
        return json.loads(row[1])

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def apply_transforms(self, image, target=None):
        # This function will be overwritten by each dataset to
        # apply transforms to image and targets.
        return image, target

    def get_img_info(self, idx):
        """Return a dict with "height" and "width" for sample ``idx``.

        Returns None when neither an id mapping nor an hw file is available,
        or when the hw column is valid JSON that is neither a list nor a dict.
        """
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            line_no = self.get_line_no(idx)
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            width = int(annotations["img_w"])
            # Fixed: "height" used to be filled from "img_w" as well.
            # Assumes label rows carry "img_h" next to "img_w" — TODO confirm
            # against the label files; if the key is absent the previous
            # value (the width) is kept so existing data still loads.
            height = int(annotations.get("img_h", width))
            return {"height": height, "width": width}
        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv.seek(line_no)
            try:
                # json string format with "height" and "width" being the keys
                data = json.loads(row[1])
                if isinstance(data, list):
                    return data[0]
                elif isinstance(data, dict):
                    return data
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(' ')
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        """Return the image key for sample ``idx``."""
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations["img_id"]
        if self.hw_tsv:
            return self.hw_tsv.seek(line_no)[0]
        elif self.label_tsv:
            return self.label_tsv.seek(line_no)[0]
        else:
            return self.img_tsv.seek(line_no)[0]
class TSVYamlDataset(TSVDataset):
    """TSVDataset configured from a YAML file.

    The YAML must have an 'img' entry and may have 'label', 'hw', 'linelist'
    and 'imageid2idx' entries; each is resolved to a path with
    ``find_file_path_in_yaml`` and handed to ``TSVDataset``.
    """

    def __init__(self, yaml_file, root=None, replace_clean_label=False):
        """
        Args:
            yaml_file: Path of the YAML configuration file.
            root: Directory used to resolve relative entries; when falsy,
                the directory of ``yaml_file`` is used.
            replace_clean_label: When True, the resolved label path must
                contain "raw_label", which is replaced by "clean_label".
        """
        print("Reading {}".format(yaml_file))
        self.cfg = load_from_yaml_file(yaml_file)
        self.root = root if root else op.dirname(yaml_file)

        def resolve_optional(key):
            # Missing keys resolve to None.
            return find_file_path_in_yaml(self.cfg.get(key, None), self.root)

        img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
        label_file = resolve_optional('label')
        hw_file = resolve_optional('hw')
        linelist_file = resolve_optional('linelist')
        imageid2idx_file = resolve_optional('imageid2idx')

        if replace_clean_label:
            assert ("raw_label" in label_file)
            label_file = label_file.replace("raw_label", "clean_label")

        super().__init__(
            img_file, label_file, hw_file, linelist_file, imageid2idx_file)
class ODTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.
    """

    def __init__(self, yaml_file, extra_fields=(), transforms=None,
                 is_load_label=True, **kwargs):
        """
        Args:
            yaml_file: Path of the dataset YAML. When None the constructor
                returns immediately and no attribute is set.
            extra_fields: Passed through to ``LabelLoader``.
            transforms: Optional callable ``(img, target) -> (img, target)``.
            is_load_label: When True, the label map is loaded from the YAML's
                "labelmap" entry (or "jsondict" if that is absent) and a
                ``LabelLoader`` is created.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If the label-map path contains neither "json"
                nor "tsv".
        """
        if yaml_file is None:
            return
        super(ODTSVDataset, self).__init__(yaml_file)
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.attribute_on = False
        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False

        if self.is_load_label:
            # construct maps
            jsondict_file = find_file_path_in_yaml(
                self.cfg.get("labelmap", None), self.root
            )
            if jsondict_file is None:
                jsondict_file = find_file_path_in_yaml(
                    self.cfg.get("jsondict", None), self.root
                )
            # NOTE(review): if the YAML has neither entry, jsondict_file is
            # still None here and the membership test below raises TypeError.
            # The format is picked by substring match on the whole path.
            if "json" in jsondict_file:
                # Context manager so the handle is closed deterministically.
                with open(jsondict_file, 'r') as fp:
                    jsondict = json.load(fp)
                if "label_to_idx" not in jsondict:
                    # A bare mapping is taken to be the label -> index map.
                    jsondict = {'label_to_idx': jsondict}
            elif "tsv" in jsondict_file:
                # One label per line; indices start at 1 (0 is background).
                with open(jsondict_file) as f:
                    label_to_idx = {
                        line.strip(): number
                        for number, line in enumerate(f, start=1)
                    }
                jsondict = {'label_to_idx': label_to_idx}
            else:
                # Raised explicitly (rather than ``assert (0)``) so the check
                # survives ``python -O``; type kept for compatibility.
                raise AssertionError(
                    'unsupported labelmap file: {}'.format(jsondict_file)
                )

            self.labelmap = {}
            self.class_to_ind = jsondict['label_to_idx']
            self.class_to_ind['__background__'] = 0
            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
            self.labelmap['class_to_ind'] = self.class_to_ind

            if self.attribute_on:
                self.attribute_to_ind = jsondict['attribute_to_idx']
                self.attribute_to_ind['__no_attribute__'] = 0
                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
                self.labelmap['attribute_to_ind'] = self.attribute_to_ind

            self.label_loader = LabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
            )

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Build the target from annotations via the label loader.

        A list of annotations is treated as the "objects" entry. Returns
        None when ``is_load_label`` is False.
        """
        if isinstance(annotations, list):
            annotations = {"objects": annotations}
        if self.is_load_label:
            return self.label_loader(annotations['objects'], img_size)

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` when set; otherwise pass through."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target