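"""Pickle-backed indexed dataset with O(1) random access by position or id.

On-disk layout of ``<path>.data`` (see IndexedDatasetBuilder.finalize):
  * bytes [0, 32):   little-endian length of the pickled index blob
  * bytes [32, default_idx_size): pickled dict {'offsets', 'id2pos', 'meta'}
  * bytes [default_idx_size, ...): concatenated (optionally gzipped) items

Once a chunk grows past ``max_size``, writing rolls over to ``<path>.1.data``,
``<path>.2.data``, ...; ``meta['chunk_begin']`` records each chunk's starting
global byte offset so a reader can map any offset to the right file.
"""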
import gzip
import pickle
from bisect import bisect
from copy import deepcopy

import numpy as np
def int2bytes(i: int, *, signed: bool = False) -> bytes:
    """Encode an int as the shortest little-endian byte string that holds it."""
    length = ((i + ((i * signed) < 0)).bit_length() + 7 + signed) // 8
    return i.to_bytes(length, byteorder='little', signed=signed)


def bytes2int(b: bytes, *, signed: bool = False) -> int:
    """Decode a little-endian byte string back to an int."""
    return int.from_bytes(b, byteorder='little', signed=signed)
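# Round-trip example: the 32-byte header slot is decoded with bytes2int, and
# little-endian zero padding does not change the value, e.g.
#   bytes2int(int2bytes(300) + b'\x00' * 30) == 300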
def load_index_data(data_file):
    """Read the index blob from an open ``.data`` file positioned at byte 0."""
    index_data_size = bytes2int(data_file.read(32))
    index_data = pickle.loads(data_file.read(index_data_size))
    data_offsets = deepcopy(index_data['offsets'])
    id2pos = deepcopy(index_data.get('id2pos', {}))
    meta = deepcopy(index_data.get('meta', {}))
    return data_offsets, id2pos, meta
class IndexedDataset:
    def __init__(self, path, unpickle=True):
        self.path = path
        self.root_data_file = open(f"{path}.data", 'rb', buffering=-1)
        try:
            # Current format: the index blob sits at the head of the .data file.
            self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
            self.data_files = [self.root_data_file]
        except Exception:
            # Legacy format: the index lives in a separate .idx numpy pickle.
            self._init_old(path)
            self.meta = {}
        self.gzip = self.meta.get('gzip', False)
        if 'chunk_begin' not in self.meta:
            self.meta['chunk_begin'] = [0]
        # One open handle per extra chunk; chunk 0 is the root .data file.
        for i in range(len(self.meta['chunk_begin'][1:])):
            self.data_files.append(open(f"{self.path}.{i + 1}.data", 'rb'))
        self.unpickle = unpickle

    def _init_old(self, path):
        self.path = path
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        # Reuse the handle opened in __init__ instead of opening the file twice.
        self.data_files = [self.root_data_file]
def __getitem__(self, i):
if self.id2pos is not None and len(self.id2pos) > 0:
i = self.id2pos[i]
self.check_index(i)
        # Map the global byte offset to its chunk, then read the item's span.
        chunk_id = bisect(self.meta['chunk_begin'][1:], self.byte_offsets[i])
        data_file = self.data_files[chunk_id]
        data_file.seek(self.byte_offsets[i] - self.meta['chunk_begin'][chunk_id])
        b = data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i])
        if self.unpickle:
            if self.gzip:
                b = gzip.decompress(b)
            item = pickle.loads(b)
        else:
            item = b
        return item
    def __del__(self):
        # getattr guard: __init__ may have failed before data_files existed.
        for data_file in getattr(self, 'data_files', []):
            data_file.close()
    def check_index(self, i):
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError(f'index {i} out of range')
def __len__(self):
return len(self.byte_offsets) - 1
    def __iter__(self):
        # A generator keeps iteration state local, so nested loops over the
        # same dataset do not clobber each other's position.
        for i in range(len(self)):
            yield self[i]
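# Lookup semantics: when items are added with an explicit ``id``, reads go
# through ``id2pos`` (``ds[id]``); when no ids were stored, the key is a
# plain positional index.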
class IndexedDatasetBuilder:
    def __init__(self, path, append=False, max_size=1024 * 1024 * 1024 * 64,
                 default_idx_size=1024 * 1024 * 16, gzip=False):
        self.path = self.root_path = path
        self.default_idx_size = default_idx_size
        if append:
            self.root_data_file = open(f"{path}.data", 'r+b')
            self.root_data_file.seek(0)
            self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
            # Blank the index header so a crash mid-append never leaves a stale index.
            self.root_data_file.seek(0)
            self.root_data_file.write(bytes(default_idx_size))
            self.gzip = self.meta.get('gzip', False)
            # Resume writing at the end of the *last* chunk, which is not the
            # root file when earlier runs rolled over to chunk files.
            self.data_chunk_id = len(self.meta['chunk_begin']) - 1
            if self.data_chunk_id == 0:
                self.data_file = self.root_data_file
            else:
                self.data_file = open(f"{path}.{self.data_chunk_id}.data", 'r+b')
            self.data_file.seek(self.byte_offsets[-1] - self.meta['chunk_begin'][-1])
        else:
            self.data_file = self.root_data_file = open(f"{path}.data", 'wb')
            # Reserve the header region; finalize() fills it with the index blob.
            self.data_file.seek(default_idx_size)
            self.byte_offsets = [default_idx_size]
            self.id2pos = {}
            self.meta = {'chunk_begin': [0]}
            self.gzip = self.meta['gzip'] = gzip
            self.data_chunk_id = 0
        self.max_size = max_size
    def add_item(self, item, id=None, use_pickle=True):
        # Roll over to a fresh chunk file once the current chunk exceeds max_size.
        if self.byte_offsets[-1] > self.meta['chunk_begin'][-1] + self.max_size:
            if self.data_file != self.root_data_file:
                self.data_file.close()
            self.data_chunk_id += 1
            self.data_file = open(f"{self.path}.{self.data_chunk_id}.data", 'wb')
            self.data_file.seek(0)
            self.meta['chunk_begin'].append(self.byte_offsets[-1])
        if not use_pickle:
            s = item
        else:
            s = pickle.dumps(item)
        if self.gzip:
            s = gzip.compress(s, 1)
        n_written = self.data_file.write(s)  # avoid shadowing the bytes builtin
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + n_written)
    def finalize(self):
        s = pickle.dumps({'offsets': self.byte_offsets, 'id2pos': self.id2pos, 'meta': self.meta})
        # The blob is written at byte 32, so it must fit in the reserved header
        # region minus the 32-byte length slot.
        assert len(s) + 32 <= self.default_idx_size, (len(s), self.default_idx_size)
        self.root_data_file.seek(0)
        self.root_data_file.write(int2bytes(len(s)))
        self.root_data_file.seek(32)
        self.root_data_file.write(s)
        self.root_data_file.close()
        if self.data_file is not self.root_data_file:
            self.data_file.close()
if __name__ == "__main__":
import random
from tqdm import tqdm
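    # Smoke test: ~160 MB of random arrays with max_size=40 MB forces several
    # chunk files, exercising both chunk rollover and chunked reads.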
ds_path = '/tmp/indexed_ds_example'
size = 100
items = [{"a": np.random.normal(size=[10000, 10]),
"b": np.random.normal(size=[10000, 10])} for i in range(size)]
builder = IndexedDatasetBuilder(ds_path, max_size=1024 * 1024 * 40)
builder.meta['lengths'] = [1, 2, 3]
for i in tqdm(range(size)):
builder.add_item(pickle.dumps(items[i]), i, use_pickle=False)
builder.finalize()
ds = IndexedDataset(ds_path)
assert ds.meta['lengths'] == [1, 2, 3]
for i in tqdm(range(1000)):
idx = random.randint(0, size - 1)
assert (ds[idx]['a'] == items[idx]['a']).all()
    # Append-mode smoke test (disabled): reopen the dataset, append the same
    # items under new ids, and verify both the old and the new half.
    # builder = IndexedDatasetBuilder(ds_path, append=True)
    # builder.meta['lengths'] = [1, 2, 3, 5, 6, 7]
    # for i in tqdm(range(size)):
    #     builder.add_item(items[i], i + size)
    # builder.finalize()
    # ds = IndexedDataset(ds_path)
    # assert ds.meta['lengths'] == [1, 2, 3, 5, 6, 7]
    # for i in tqdm(range(1000)):
    #     idx = random.randint(size, 2 * size - 1)
    #     assert (ds[idx]['a'] == items[idx - size]['a']).all()
    #     idx = random.randint(0, size - 1)
    #     assert (ds[idx]['a'] == items[idx]['a']).all()