Spaces:
Running
Running
import struct | |
from enum import IntEnum | |
from fsspec.spec import AbstractBufferedFile | |
from typing import Any, Iterator, NamedTuple | |
class GGUFValueType(IntEnum):
    """Integer tags identifying the type of a GGUF metadata value.

    The tags are used both as keys of the scalar lookup tables in this module
    and as the on-disk type identifiers that are packed/unpacked as UINT32.
    """

    # Integers up to 32 bits, unsigned/signed pairs.
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5

    # Single-precision float and boolean.
    FLOAT32 = 6
    BOOL = 7

    # Variable-length types (no fixed scalar size).
    STRING = 8
    ARRAY = 9

    # 64-bit scalars.
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
# Table of well-known metadata keys.  Each entry maps a key to a
# (GGUFValueType, value) pair; presumably the value is the default offered for
# that key (verify against the consumer of this table).  Entries whose value is
# a list pair it with the type of the list's elements.
# NOTE: the `[]` values are shared mutable objects; copy before mutating.
standard_metadata = {
    # --- general.* ---------------------------------------------------------
    "general.type": (GGUFValueType.STRING, "model"),
    "general.architecture": (GGUFValueType.STRING, "llama"),
    "general.quantization_version": (GGUFValueType.UINT32, 2),
    "general.alignment": (GGUFValueType.UINT32, 32),
    "general.file_type": (GGUFValueType.UINT32, 0),
    "general.name": (GGUFValueType.STRING, ""),
    "general.author": (GGUFValueType.STRING, ""),
    "general.version": (GGUFValueType.STRING, ""),
    "general.organization": (GGUFValueType.STRING, ""),
    "general.finetune": (GGUFValueType.STRING, ""),
    "general.basename": (GGUFValueType.STRING, ""),
    "general.description": (GGUFValueType.STRING, ""),
    "general.quantized_by": (GGUFValueType.STRING, ""),
    "general.size_label": (GGUFValueType.STRING, ""),
    "general.license": (GGUFValueType.STRING, ""),
    "general.license.name": (GGUFValueType.STRING, ""),
    "general.license.link": (GGUFValueType.STRING, ""),
    "general.url": (GGUFValueType.STRING, ""),
    "general.doi": (GGUFValueType.STRING, ""),
    "general.uuid": (GGUFValueType.STRING, ""),
    "general.repo_url": (GGUFValueType.STRING, ""),
    "general.source.url": (GGUFValueType.STRING, ""),
    "general.source.doi": (GGUFValueType.STRING, ""),
    "general.source.uuid": (GGUFValueType.STRING, ""),
    "general.source.repo_url": (GGUFValueType.STRING, ""),
    "general.tags": (GGUFValueType.STRING, []),
    "general.languages": (GGUFValueType.STRING, []),
    "general.datasets": (GGUFValueType.STRING, []),

    # --- split.* -----------------------------------------------------------
    "split.no": (GGUFValueType.UINT16, 0),
    "split.count": (GGUFValueType.UINT16, 0),
    "split.tensors.count": (GGUFValueType.UINT32, 0),

    # --- tokenizer.* -------------------------------------------------------
    "tokenizer.ggml.model": (GGUFValueType.STRING, "gpt2"),
    "tokenizer.ggml.pre": (GGUFValueType.STRING, "llama-bpe"),
    "tokenizer.ggml.tokens": (GGUFValueType.STRING, []),
    "tokenizer.ggml.token_type": (GGUFValueType.INT32, []),
    "tokenizer.ggml.scores": (GGUFValueType.FLOAT32, []),
    "tokenizer.ggml.merges": (GGUFValueType.STRING, []),
    "tokenizer.ggml.bos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.unknown_token_id": (GGUFValueType.UINT32, 0),
    # The spelling "seperator" is part of the runtime key and is kept as-is;
    # presumably it matches the key name used by GGUF writers -- confirm
    # before "correcting" it.
    "tokenizer.ggml.seperator_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.padding_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.cls_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.mask_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.add_bos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_eos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_space_prefix": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.remove_extra_whitespaces": (GGUFValueType.BOOL, False),
    "tokenizer.chat_template": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.rag": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.tool_use": (GGUFValueType.STRING, ""),
    "tokenizer.chat_templates": (GGUFValueType.STRING, []),
    "tokenizer.ggml.prefix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.suffix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.middle_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eot_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eom_token_id": (GGUFValueType.UINT32, 0),

    # --- quantize.imatrix.* ------------------------------------------------
    "quantize.imatrix.file": (GGUFValueType.STRING, ""),
    "quantize.imatrix.dataset": (GGUFValueType.STRING, ""),
    "quantize.imatrix.entries_count": (GGUFValueType.INT32, 0),
    "quantize.imatrix.chunks_count": (GGUFValueType.INT32, 0),
}
# Number of bytes occupied by one value of each fixed-size scalar type.
# STRING and ARRAY are deliberately absent: they have no fixed size.
gguf_scalar_size: dict[GGUFValueType, int] = {
    # 1 byte
    GGUFValueType.UINT8: 1,
    GGUFValueType.INT8: 1,
    # 2 bytes
    GGUFValueType.UINT16: 2,
    GGUFValueType.INT16: 2,
    # 4 bytes
    GGUFValueType.UINT32: 4,
    GGUFValueType.INT32: 4,
    GGUFValueType.FLOAT32: 4,
    # 1 byte
    GGUFValueType.BOOL: 1,
    # 8 bytes
    GGUFValueType.UINT64: 8,
    GGUFValueType.INT64: 8,
    GGUFValueType.FLOAT64: 8,
}
# `struct` format character for each fixed-size scalar type.  Membership in
# this table is also what the reader/writer code uses to decide whether a
# type is a plain scalar (STRING and ARRAY are not listed).
gguf_scalar_pack: dict[GGUFValueType, str] = {
    # unsigned / signed integers up to 32 bits
    GGUFValueType.UINT8: "B",
    GGUFValueType.INT8: "b",
    GGUFValueType.UINT16: "H",
    GGUFValueType.INT16: "h",
    GGUFValueType.UINT32: "I",
    GGUFValueType.INT32: "i",
    # float and bool
    GGUFValueType.FLOAT32: "f",
    GGUFValueType.BOOL: "?",
    # 64-bit scalars
    GGUFValueType.UINT64: "Q",
    GGUFValueType.INT64: "q",
    GGUFValueType.FLOAT64: "d",
}
class GGUFData(NamedTuple):
    """A decoded value together with the raw bytes it was read from or packed to."""

    # Type tag of the value; None is used for untyped data such as the magic.
    type: GGUFValueType | None
    # Decoded Python value (scalar, str or list, depending on the type).
    value: Any
    # Serialized bytes corresponding to the value.
    data: bytes
class HuggingGGUFstream:
    """Reader and in-memory editor for the header and metadata of a GGUF file.

    The constructor consumes the fixed header from ``fp``: the 4-byte magic, a
    UINT32 version and two UINT64 fields stored under ``'tensors'`` and
    ``'metadata'``.  Metadata entries are read lazily by :meth:`read_metadata`
    and can then be added, replaced or removed in memory.  Every entry keeps
    its serialized bytes, which is how ``filesize`` and the header's metadata
    count are kept up to date.
    """

    fp: AbstractBufferedFile
    # Header fields by name: 'magic', 'version', 'tensors', 'metadata'.
    header: dict[str, GGUFData]
    # Metadata entries by key; each entry's .data holds the whole serialized
    # entry (key + type tag + value).
    metadata: dict[str, GGUFData]
    # struct byte-order prefix: '<' (default) or '>' for byte-swapped files.
    endian: str
    # fp.loc recorded once read_metadata() has read every entry; 0 before that.
    metaend: int
    # Size reported by fp.details, adjusted by add_metadata/remove_metadata.
    filesize: int

    def __init__(
        self,
        fp: AbstractBufferedFile,
    ) -> None:
        """Read and validate the GGUF header from ``fp``.

        Args:
            fp: Open file positioned at the start of the GGUF data.  Its
                ``details`` mapping is queried for ``'size'``.

        Raises:
            TypeError: If the magic is not ``b'GGUF'`` or the version is not 3
                in either byte order.
        """
        self.fp = fp
        self.header = {}
        self.metadata = {}
        self.endian = '<'
        self.metaend = 0
        self.filesize = fp.details.get('size')

        if (data := self.fp.read(4)) != b'GGUF':
            raise TypeError('File is not a GGUF')

        self.header['magic'] = GGUFData(
            type = None,
            value = None,
            data = data,
        )

        data = self._read_field(GGUFValueType.UINT32)
        if data.value != 3:
            # A version of 3 written big-endian reads as 3 << 24 when decoded
            # little-endian; treat that as a big-endian file.
            if data.value == 3 << 24:
                # GGUFData is an immutable NamedTuple: assigning to .value
                # raises AttributeError, so build a corrected copy instead.
                data = data._replace(value = 3)
                self.endian = '>'
            else:
                raise TypeError('Unsupported GGUF version')
        self.header['version'] = data

        # Both counts are read after the endianness has been settled.
        self.header['tensors'] = self._read_field(GGUFValueType.UINT64)
        self.header['metadata'] = self._read_field(GGUFValueType.UINT64)

    def _unpack_field(
        self,
        buffer: bytes,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> Any:
        """Decode ``repeat`` scalars of ``field_type`` from ``buffer``.

        Returns:
            The bare scalar when ``repeat == 1``, otherwise a tuple of
            ``repeat`` values (an empty tuple for ``repeat == 0``).
        """
        value = struct.unpack(f'{self.endian}{repeat}{gguf_scalar_pack.get(field_type)}', buffer)
        return value[0] if repeat == 1 else value

    def _pack_field(
        self,
        field_type: GGUFValueType,
        *values,
    ) -> bytes:
        """Encode ``values`` as consecutive scalars of ``field_type``."""
        return struct.pack(f'{self.endian}{len(values)}{gguf_scalar_pack.get(field_type)}', *values)

    def _pack_value(
        self,
        val_type: GGUFValueType,
        value: Any,
    ) -> bytes:
        """Serialize a metadata value (without its key or outer type tag).

        A ``list`` value is written as an array: element type tag (UINT32),
        element count (UINT64), then the elements.  Strings are written as a
        UINT64 byte length followed by the UTF-8 bytes.

        Args:
            val_type: Type of the value, or of each element when ``value`` is
                a list.
            value: Scalar, string, or list of either.

        Returns:
            The serialized bytes.

        Raises:
            TypeError: If ``val_type`` is ARRAY (nested arrays are not
                supported) or is not a known type.
        """
        is_array = isinstance(value, list)

        # Collect the pieces and join once; repeated `bytes +=` is quadratic
        # for long string arrays.
        parts: list[bytes] = []
        if is_array:
            parts.append(self._pack_field(GGUFValueType.UINT32, val_type))
            parts.append(self._pack_field(GGUFValueType.UINT64, len(value)))

        if val_type == GGUFValueType.ARRAY:
            raise TypeError('Array of arrays currently unsupported')
        elif val_type == GGUFValueType.STRING:
            for item in (value if is_array else (value,)):
                buf = str(item).encode('utf-8')
                parts.append(self._pack_field(GGUFValueType.UINT64, len(buf)))
                parts.append(buf)
        elif val_type in gguf_scalar_pack:
            parts.append(self._pack_field(val_type, *(value if is_array else (value,))))
        else:
            raise TypeError('Unknown metadata type')

        return b''.join(parts)

    def _read_field(
        self,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> GGUFData:
        """Read ``repeat`` scalars of ``field_type`` from ``self.fp``.

        Returns:
            A GGUFData whose ``value`` follows the :meth:`_unpack_field`
            convention (bare scalar for ``repeat == 1``, tuple otherwise) and
            whose ``data`` is the raw bytes read.
        """
        data = self.fp.read(gguf_scalar_size.get(field_type) * repeat)
        value = self._unpack_field(data, field_type, repeat = repeat)

        return GGUFData(
            type = field_type,
            value = value,
            data = data,
        )

    def _read_value(
        self,
        val_type: GGUFValueType,
    ) -> GGUFData:
        """Read one metadata value of ``val_type`` from ``self.fp``.

        Returns:
            A GGUFData holding the decoded value and every byte consumed for
            it.  For arrays ``value`` is a list and ``type`` is the element
            type rather than ARRAY.

        Raises:
            TypeError: If ``val_type`` (or an array's element type, when the
                array is non-empty) is not a known type.
        """
        if val_type == GGUFValueType.ARRAY:
            elem_type = self._read_field(GGUFValueType.UINT32)
            count = self._read_field(GGUFValueType.UINT64)

            if elem_type.value in gguf_scalar_pack:
                # Fixed-size elements: read the whole array in one go.
                raw = self._read_field(elem_type.value, repeat = count.value)
                # _unpack_field collapses a single result to a bare scalar,
                # which list() cannot iterate; wrap it explicitly.
                values = [raw.value] if count.value == 1 else list(raw.value)

                return GGUFData(
                    type = raw.type,
                    value = values,
                    data = elem_type.data + count.data + raw.data,
                )

            # Variable-size elements (strings, nested arrays): one at a time.
            values = []
            chunks = [elem_type.data, count.data]
            for _ in range(count.value):
                item = self._read_value(elem_type.value)
                chunks.append(item.data)
                values.append(item.value)

            return GGUFData(
                type = elem_type.value,
                value = values,
                data = b''.join(chunks),
            )

        if val_type == GGUFValueType.STRING:
            length = self._read_field(GGUFValueType.UINT64)
            raw = self.fp.read(length.value)

            return GGUFData(
                type = val_type,
                value = raw.decode('utf-8'),
                data = length.data + raw,
            )

        if val_type in gguf_scalar_pack:
            return self._read_field(val_type)

        raise TypeError('Unknown metadata type')

    def _update_metacount(
        self,
    ) -> None:
        """Rewrite ``header['metadata']`` to reflect ``len(self.metadata)``."""
        old_count = self.header['metadata']
        new_count = len(self.metadata)

        self.header['metadata'] = GGUFData(
            type = old_count.type,
            value = new_count,
            data = self._pack_field(old_count.type, new_count),
        )

    def read_metadata(
        self,
    ) -> Iterator[tuple[str, GGUFData]]:
        """Yield ``(key, entry)`` for every metadata entry.

        If ``self.metadata`` already holds entries (from an earlier call or
        from :meth:`add_metadata`) those cached entries are yielded and the
        file is not read.  Otherwise entries are read from ``self.fp``,
        cached, and yielded one by one; once all of them have been consumed
        ``self.metaend`` is set to the current file position.

        Each cached entry's ``data`` is the complete serialized entry: key
        bytes, type tag bytes and value bytes.
        """
        if self.metadata:
            yield from self.metadata.items()
            return

        for _ in range(self.header['metadata'].value):
            key = self._read_value(GGUFValueType.STRING)
            val_type = self._read_field(GGUFValueType.UINT32)
            val = self._read_value(val_type.value)

            entry = GGUFData(
                type = val.type,
                value = val.value,
                data = key.data + val_type.data + val.data,
            )
            self.metadata[key.value] = entry

            yield key.value, entry

        self.metaend = self.fp.loc

    def add_metadata(
        self,
        key: str,
        type: GGUFValueType,
        value: Any,
    ) -> None:
        """Add a metadata entry, or replace the existing one for ``key``.

        ``filesize`` is adjusted by the difference in serialized size, and the
        header's metadata count is refreshed when the key is new.

        Args:
            key: Metadata key.
            type: Type of ``value``, or of its elements when ``value`` is a
                list (the entry is then written with an ARRAY type tag).
            value: Scalar, string, or list of either.

        Raises:
            TypeError: Propagated from :meth:`_pack_value` for unsupported
                types.
        """
        data = self._pack_value(GGUFValueType.STRING, key)
        data += self._pack_field(GGUFValueType.UINT32, GGUFValueType.ARRAY if isinstance(value, list) else type)
        data += self._pack_value(type, value)

        previous = self.metadata.get(key)
        if previous is not None:
            # Replacing: the old entry's bytes no longer count.
            self.filesize -= len(previous.data)

        self.filesize += len(data)
        self.metadata[key] = GGUFData(
            type = type,
            value = value,
            data = data,
        )

        if previous is None:
            self._update_metacount()

    def remove_metadata(
        self,
        key: str,
    ) -> None:
        """Remove the entry for ``key`` if present; otherwise do nothing.

        On removal ``filesize`` shrinks by the entry's serialized size and the
        header's metadata count is refreshed.
        """
        removed = self.metadata.pop(key, None)
        if removed is not None:
            self.filesize -= len(removed.data)
            self._update_metacount()