import os
import os.path as op
import gc
import json
from typing import List
import logging

try:
    from .blob_storage import BlobStorage, disk_usage
except ImportError:
    # Fallback stubs so the module still imports when blob_storage is unavailable.
    class BlobStorage:
        pass

    def disk_usage(path):
        return 0.0


def generate_lineidx(filein: str, idxout: str) -> None:
    idxout_tmp = idxout + '.tmp'
    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)


def read_to_character(fp, c):
    result = []
    while True:
        s = fp.read(32)
        assert s != ''
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return ''.join(result)


class TSVFile(object):
    def __init__(self,
                 tsv_file: str,
                 if_generate_lineidx: bool = False,
                 lineidx: str = None,
                 class_selector: List[str] = None,
                 blob_storage: BlobStorage = None):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
            if not lineidx else lineidx
        self.linelist = op.splitext(tsv_file)[0] + '.linelist'
        self.chunks = op.splitext(tsv_file)[0] + '.chunks'
        self._fp = None
        self._lineidx = None
        self._sample_indices = None
        self._class_boundaries = None
        self._class_selector = class_selector
        self._blob_storage = blob_storage
        self._len = None
        # Remember the pid of the process that opened the file.
        # If the current pid differs (e.g. after a fork), the file is re-opened.
        self.pid = None

        # generate the lineidx file if it does not exist
        if not op.isfile(self.lineidx) and if_generate_lineidx:
            generate_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        self.gcidx()
        if self._fp:
            self._fp.close()
            # physically remove the tsv file if it was retrieved by BlobStorage
            if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
                try:
                    original_usage = disk_usage('/')
                    os.remove(self.tsv_file)
                    logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
                                 (self.tsv_file, original_usage * 100, disk_usage('/') * 100))
                except Exception:
                    # Known issue: multiple threads attempting to delete the file will raise a FileNotFoundError.
                    # TODO: try threading.Lock to better handle the race condition
                    pass

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def gcidx(self):
        logging.debug('Run gc collect')
        self._lineidx = None
        self._sample_indices = None
        # self._class_boundaries = None
        return gc.collect()

    def get_class_boundaries(self):
        return self._class_boundaries

    def num_rows(self, gcf=False):
        if self._len is None:
            self._ensure_lineidx_loaded()
            retval = len(self._sample_indices)

            if gcf:
                self.gcidx()

            self._len = retval

        return self._len

    def seek(self, idx: int):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[self._sample_indices[idx]]
        except Exception:
            logging.info('=> {}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx: int):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx: int):
        return self.seek_first_column(idx)

    def __getitem__(self, index: int):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            logging.debug('=> loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                lines = fp.readlines()
                lines = [line.strip() for line in lines]
                self._lineidx = [int(line) for line in lines]

            # read the line list if it exists
            linelist = None
            if op.isfile(self.linelist):
                with open(self.linelist, 'r') as fp:
                    linelist = sorted(
                        [int(line.strip()) for line in fp.readlines()]
                    )

            if op.isfile(self.chunks):
                # A .chunks file maps class names to inclusive [first_row, last_row]
                # boundaries; a class_selector is expected when it is present.
                self._sample_indices = []
                self._class_boundaries = []
                with open(self.chunks, 'r') as fp:
                    class_boundaries = json.load(fp)
                for class_name, boundary in class_boundaries.items():
                    start = len(self._sample_indices)
                    if class_name in self._class_selector:
                        for idx in range(boundary[0], boundary[1] + 1):
                            # NOTE: potentially slow when linelist is long, try to speed it up
                            if linelist and idx not in linelist:
                                continue
                            self._sample_indices.append(idx)
                    end = len(self._sample_indices)
                    self._class_boundaries.append((start, end))
            else:
                if linelist:
                    self._sample_indices = linelist
                else:
                    self._sample_indices = list(range(len(self._lineidx)))

    def _ensure_tsv_opened(self):
        if self._fp is None:
            if self._blob_storage:
                self._fp = self._blob_storage.open(self.tsv_file)
            else:
                self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()


class TSVWriter(object):
    def __init__(self, tsv_file):
        self.tsv_file = tsv_file
        self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
        self.tsv_file_tmp = self.tsv_file + '.tmp'
        self.lineidx_file_tmp = self.lineidx_file + '.tmp'

        self.tsv_fp = open(self.tsv_file_tmp, 'w')
        self.lineidx_fp = open(self.lineidx_file_tmp, 'w')

        self.idx = 0

    def write(self, values, sep='\t'):
        v = '{0}\n'.format(sep.join(map(str, values)))
        self.tsv_fp.write(v)
        self.lineidx_fp.write(str(self.idx) + '\n')
        self.idx = self.idx + len(v)

    def close(self):
        self.tsv_fp.close()
        self.lineidx_fp.close()
        os.rename(self.tsv_file_tmp, self.tsv_file)
        os.rename(self.lineidx_file_tmp, self.lineidx_file)
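

# Minimal usage sketch (not part of the original module): shows how TSVWriter and
# TSVFile are typically paired. The file name 'example.tsv' and the row contents
# below are hypothetical placeholders, not anything referenced by this module.
if __name__ == '__main__':
    # Write a small TSV; close() renames the .tmp files into example.tsv
    # and its matching example.lineidx offset file.
    writer = TSVWriter('example.tsv')
    writer.write(['key_0', 'label_a', 'payload_0'])
    writer.write(['key_1', 'label_b', 'payload_1'])
    writer.close()

    # Random-access reads via the byte offsets stored in the .lineidx file.
    reader = TSVFile('example.tsv')
    print(len(reader))        # 2
    print(reader.get_key(1))  # 'key_1'
    print(reader[0])          # ['key_0', 'label_a', 'payload_0']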