gguf-editor / _hf_gguf.py
CISCai's picture
Initial version
e738e15 verified
raw
history blame
11.4 kB
import struct
from enum import IntEnum
from fsspec.spec import AbstractBufferedFile
from typing import Any, Iterator, NamedTuple
class GGUFValueType(IntEnum):
    """Type tags for GGUF metadata values.

    Each member's integer value is the tag read from / written to the file
    (as a UINT32) in front of a metadata value or array body.
    """

    # Fixed-width integers, unsigned/signed pairs of growing width.
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    # Single-precision float and one-byte boolean.
    FLOAT32 = 6
    BOOL = 7
    # Variable-length kinds: a UINT64 length followed by UTF-8 bytes, and an
    # element-type tag + UINT64 count followed by the elements.
    STRING = 8
    ARRAY = 9
    # 64-bit scalars, numbered after ARRAY.
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
# Well-known GGUF metadata keys, each mapped to a (value type, initial value)
# pair. When the initial value is a list, the type is the array's *element*
# type: HuggingGGUFstream.add_metadata / _pack_value encode list values with
# an ARRAY tag followed by this element type.
# NOTE(review): the initial values look like defaults offered when a key is
# newly added (e.g. by an editor UI) — confirm against callers. The list
# values are shared module-level objects; copy them before mutating.
# NOTE(review): "seperator" in tokenizer.ggml.seperator_token_id is spelled
# this way on purpose as far as this file is concerned — the string must
# match the key stored in files; verify the upstream spelling before changing.
standard_metadata = {
    # general.* — model identity, provenance and licensing.
    "general.type": (GGUFValueType.STRING, "model"),
    "general.architecture": (GGUFValueType.STRING, "llama"),
    "general.quantization_version": (GGUFValueType.UINT32, 2),
    "general.alignment": (GGUFValueType.UINT32, 32),
    "general.file_type": (GGUFValueType.UINT32, 0),
    "general.name": (GGUFValueType.STRING, ""),
    "general.author": (GGUFValueType.STRING, ""),
    "general.version": (GGUFValueType.STRING, ""),
    "general.organization": (GGUFValueType.STRING, ""),
    "general.finetune": (GGUFValueType.STRING, ""),
    "general.basename": (GGUFValueType.STRING, ""),
    "general.description": (GGUFValueType.STRING, ""),
    "general.quantized_by": (GGUFValueType.STRING, ""),
    "general.size_label": (GGUFValueType.STRING, ""),
    "general.license": (GGUFValueType.STRING, ""),
    "general.license.name": (GGUFValueType.STRING, ""),
    "general.license.link": (GGUFValueType.STRING, ""),
    "general.url": (GGUFValueType.STRING, ""),
    "general.doi": (GGUFValueType.STRING, ""),
    "general.uuid": (GGUFValueType.STRING, ""),
    "general.repo_url": (GGUFValueType.STRING, ""),
    "general.source.url": (GGUFValueType.STRING, ""),
    "general.source.doi": (GGUFValueType.STRING, ""),
    "general.source.uuid": (GGUFValueType.STRING, ""),
    "general.source.repo_url": (GGUFValueType.STRING, ""),
    # Arrays of strings (element type STRING).
    "general.tags": (GGUFValueType.STRING, []),
    "general.languages": (GGUFValueType.STRING, []),
    "general.datasets": (GGUFValueType.STRING, []),
    # split.* — presumably describes multi-file (sharded) models; confirm.
    "split.no": (GGUFValueType.UINT16, 0),
    "split.count": (GGUFValueType.UINT16, 0),
    "split.tensors.count": (GGUFValueType.UINT32, 0),
    # tokenizer.* — vocabulary, special token ids, flags and chat templates.
    "tokenizer.ggml.model": (GGUFValueType.STRING, "gpt2"),
    "tokenizer.ggml.pre": (GGUFValueType.STRING, "llama-bpe"),
    # Per-token arrays; the type is the element type.
    "tokenizer.ggml.tokens": (GGUFValueType.STRING, []),
    "tokenizer.ggml.token_type": (GGUFValueType.INT32, []),
    "tokenizer.ggml.scores": (GGUFValueType.FLOAT32, []),
    "tokenizer.ggml.merges": (GGUFValueType.STRING, []),
    "tokenizer.ggml.bos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.unknown_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.seperator_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.padding_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.cls_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.mask_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.add_bos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_eos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_space_prefix": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.remove_extra_whitespaces": (GGUFValueType.BOOL, False),
    "tokenizer.chat_template": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.rag": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.tool_use": (GGUFValueType.STRING, ""),
    "tokenizer.chat_templates": (GGUFValueType.STRING, []),
    "tokenizer.ggml.prefix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.suffix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.middle_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eot_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eom_token_id": (GGUFValueType.UINT32, 0),
    # quantize.* — importance-matrix (imatrix) provenance.
    "quantize.imatrix.file": (GGUFValueType.STRING, ""),
    "quantize.imatrix.dataset": (GGUFValueType.STRING, ""),
    "quantize.imatrix.entries_count": (GGUFValueType.INT32, 0),
    "quantize.imatrix.chunks_count": (GGUFValueType.INT32, 0),
}
# struct format character for every fixed-width GGUF scalar type. STRING and
# ARRAY are variable-length, so they are absent from both tables; their
# absence is also how callers tell scalars from non-scalars.
gguf_scalar_pack: dict[GGUFValueType, str] = {
    GGUFValueType.UINT8: "B",
    GGUFValueType.INT8: "b",
    GGUFValueType.UINT16: "H",
    GGUFValueType.INT16: "h",
    GGUFValueType.UINT32: "I",
    GGUFValueType.INT32: "i",
    GGUFValueType.FLOAT32: "f",
    GGUFValueType.BOOL: "?",
    GGUFValueType.UINT64: "Q",
    GGUFValueType.INT64: "q",
    GGUFValueType.FLOAT64: "d",
}

# Encoded width in bytes of each scalar type, derived from the format
# characters above so the two tables cannot drift apart. An explicit '<'
# prefix selects struct's standard (platform-independent) sizes, the same
# sizes used with the '<' / '>' byte orders the stream reader packs with.
gguf_scalar_size: dict[GGUFValueType, int] = {
    value_type: struct.calcsize('<' + code)
    for value_type, code in gguf_scalar_pack.items()
}
class GGUFData(NamedTuple):
    """One parsed GGUF field: its type tag, decoded value and raw encoding.

    Instances are immutable; use ``_replace`` to derive a modified copy.
    """

    # Value type tag. None for the magic-bytes header entry; for arrays this
    # is the element type (the value is then a list).
    type: GGUFValueType | None
    # Decoded Python value (int/float/bool/str, or a list for arrays).
    value: Any
    # Encoded bytes of the field. For metadata records this is the whole
    # record: encoded key, type tag and value.
    data: bytes
class HuggingGGUFstream:
    """Sequentially parse, and edit in memory, the header and metadata of a GGUF file.

    The header (magic, version, tensor count, metadata count) is read in
    ``__init__``; metadata records are read on demand by :meth:`read_metadata`.
    Every parsed field keeps its encoded bytes in ``GGUFData.data``, presumably
    so callers can write the file back out. :meth:`add_metadata` and
    :meth:`remove_metadata` edit the cached metadata and keep the header's
    metadata count and ``filesize`` in sync.
    """

    fp: AbstractBufferedFile
    # Header fields keyed 'magic', 'version', 'tensors', 'metadata'.
    header: dict[str, GGUFData]
    # Metadata key -> record whose .data is the full encoded key/type/value.
    metadata: dict[str, GGUFData]
    # struct byte-order prefix: '<' (little-endian) or '>' (big-endian).
    endian: str
    # File offset just past the last metadata record; 0 until read_metadata
    # has read every record.
    metaend: int
    # Size reported by fp.details, adjusted by metadata edits.
    filesize: int

    def __init__(
        self,
        fp: AbstractBufferedFile,
    ) -> None:
        """Read and validate the GGUF header.

        Args:
            fp: fsspec buffered file positioned at the start of the GGUF data.

        Raises:
            TypeError: if the magic bytes are not ``GGUF`` or the version is
                not 3 in either byte order.
        """
        self.fp = fp
        self.header = {}
        self.metadata = {}
        self.endian = '<'
        self.metaend = 0
        self.filesize = fp.details.get('size')

        if (data := self.fp.read(4)) != b'GGUF':
            raise TypeError('File is not a GGUF')
        self.header['magic'] = GGUFData(
            type = None,
            value = None,
            data = data,
        )

        data = self._read_field(GGUFValueType.UINT32)
        if data.value != 3:
            # Version 3 of a big-endian file decodes as 3 << 24 when read
            # little-endian; switch byte order for every following field.
            if data.value == 3 << 24:
                # GGUFData is an immutable NamedTuple: assigning data.value
                # raised AttributeError, so derive a corrected copy instead.
                data = data._replace(value = 3)
                self.endian = '>'
            else:
                raise TypeError('Unsupported GGUF version')
        self.header['version'] = data
        self.header['tensors'] = self._read_field(GGUFValueType.UINT64)
        self.header['metadata'] = self._read_field(GGUFValueType.UINT64)

    def _unpack_field(
        self,
        buffer: bytes,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> Any:
        """Decode ``repeat`` consecutive scalars of ``field_type`` from ``buffer``.

        Returns a bare value when ``repeat`` is 1, otherwise a tuple.
        """
        value = struct.unpack(f'{self.endian}{repeat}{gguf_scalar_pack.get(field_type)}', buffer)
        return value[0] if repeat == 1 else value

    def _pack_field(
        self,
        field_type: GGUFValueType,
        *values,
    ) -> bytes:
        """Encode each of ``values`` as a scalar of ``field_type`` in the file's byte order."""
        return struct.pack(f'{self.endian}{len(values)}{gguf_scalar_pack.get(field_type)}', *values)

    def _pack_value(
        self,
        val_type: GGUFValueType,
        value: Any,
    ) -> bytes:
        """Encode ``value`` as a GGUF value of ``val_type``, without a leading type tag.

        A list is encoded as an array body (element type tag, UINT64 element
        count, elements) with ``val_type`` as the element type. Strings are a
        UINT64 byte length followed by UTF-8 bytes; non-str items are
        converted with ``str()``.

        Raises:
            TypeError: for arrays of arrays or an unknown ``val_type``.
        """
        if val_type == GGUFValueType.ARRAY:
            raise TypeError('Array of arrays currently unsupported')
        if val_type != GGUFValueType.STRING and val_type not in gguf_scalar_pack:
            raise TypeError('Unknown metadata type')

        is_array = isinstance(value, list)
        items = value if is_array else [value]
        # Collect pieces and join once: repeated bytes += is quadratic, which
        # is prohibitive for large string arrays such as tokenizer vocabularies.
        chunks = []
        if is_array:
            chunks.append(self._pack_field(GGUFValueType.UINT32, val_type))
            chunks.append(self._pack_field(GGUFValueType.UINT64, len(value)))
        if val_type == GGUFValueType.STRING:
            for item in items:
                buf = str(item).encode('utf-8')
                chunks.append(self._pack_field(GGUFValueType.UINT64, len(buf)))
                chunks.append(buf)
        else:
            chunks.append(self._pack_field(val_type, *items))
        return b''.join(chunks)

    def _read_field(
        self,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> GGUFData:
        """Read ``repeat`` scalars of ``field_type`` from the current file position.

        ``.value`` is a bare scalar when ``repeat`` is 1, otherwise a tuple.
        """
        data = self.fp.read(gguf_scalar_size.get(field_type) * repeat)
        value = self._unpack_field(data, field_type, repeat = repeat)
        return GGUFData(
            type = field_type,
            value = value,
            data = data,
        )

    def _read_value(
        self,
        val_type: GGUFValueType,
    ) -> GGUFData:
        """Read one value of ``val_type`` from the current file position.

        For ARRAY the result's ``type`` is the element type tag, ``value`` is a
        list, and ``data`` covers the element type, count and all elements.

        Raises:
            TypeError: on an unknown type tag (for array elements, only when
                the array is non-empty).
        """
        if val_type == GGUFValueType.ARRAY:
            elem_type = self._read_field(GGUFValueType.UINT32)
            count = self._read_field(GGUFValueType.UINT64)
            if elem_type.value in gguf_scalar_pack:
                # Fixed-width elements: read them all in one go.
                elems = self._read_field(elem_type.value, repeat = count.value)
                # _read_field yields a bare scalar (not a tuple) for exactly one
                # element; list() on it would raise TypeError.
                values = [elems.value] if count.value == 1 else list(elems.value)
                return GGUFData(
                    type = elems.type,
                    value = values,
                    data = elem_type.data + count.data + elems.data,
                )
            # Variable-length elements (strings, nested arrays): one at a time.
            parts = [elem_type.data, count.data]
            values = []
            for _ in range(count.value):
                elem = self._read_value(elem_type.value)
                parts.append(elem.data)
                values.append(elem.value)
            return GGUFData(
                type = elem_type.value,
                value = values,
                data = b''.join(parts),
            )
        if val_type == GGUFValueType.STRING:
            length = self._read_field(GGUFValueType.UINT64)
            raw = self.fp.read(length.value)
            return GGUFData(
                type = val_type,
                value = raw.decode('utf-8'),
                data = length.data + raw,
            )
        if val_type in gguf_scalar_pack:
            return self._read_field(val_type)
        raise TypeError('Unknown metadata type')

    def _update_metacount(
        self,
    ) -> None:
        """Rewrite header['metadata'] to the current number of cached metadata records."""
        old_count = self.header['metadata']
        new_count = len(self.metadata)
        self.header['metadata'] = GGUFData(
            type = old_count.type,
            value = new_count,
            data = self._pack_field(old_count.type, new_count),
        )

    def read_metadata(
        self,
    ) -> Iterator[tuple[str, GGUFData]]:
        """Yield ``(key, record)`` for every metadata record.

        The first complete pass reads ``header['metadata']`` records from the
        current file position (right after the header, as left by __init__),
        caches them in ``self.metadata`` and sets ``metaend``. Later calls
        yield the cache instead of touching the file. If a first pass is
        abandoned part-way, later calls yield only the records read so far.
        """
        # metaend guards the case where the cache is legitimately empty
        # (no records in the file, or all removed); re-reading then would
        # clobber metaend with wherever fp happens to be now.
        if self.metadata or self.metaend:
            yield from self.metadata.items()
            return

        for _ in range(self.header['metadata'].value):
            key = self._read_value(GGUFValueType.STRING)
            val_type = self._read_field(GGUFValueType.UINT32)
            val = self._read_value(val_type.value)
            # Store the full record bytes (key, type tag, value).
            self.metadata[key.value] = val = GGUFData(
                type = val.type,
                value = val.value,
                data = key.data + val_type.data + val.data,
            )
            yield key.value, val
        self.metaend = self.fp.loc

    def add_metadata(
        self,
        key: str,
        type: GGUFValueType,
        value: Any,
    ) -> None:
        """Add or replace metadata record ``key``.

        Args:
            key: Metadata key.
            type: Value type; for a list ``value`` this is the element type.
            value: Scalar value, or a list to store as an array.

        Raises:
            TypeError: propagated from encoding (unknown type, array of arrays);
                nothing is modified in that case.
        """
        # Encode first so a failure leaves the object untouched.
        data = self._pack_value(GGUFValueType.STRING, key)
        data += self._pack_field(GGUFValueType.UINT32, GGUFValueType.ARRAY if isinstance(value, list) else type)
        data += self._pack_value(type, value)

        previous = self.metadata.get(key)
        if previous is not None:
            self.filesize -= len(previous.data)
        self.filesize += len(data)
        self.metadata[key] = GGUFData(
            type = type,
            value = value,
            data = data,
        )
        # Only a brand-new key changes the record count.
        if previous is None:
            self._update_metacount()

    def remove_metadata(
        self,
        key: str,
    ) -> None:
        """Remove metadata record ``key`` if present; a missing key is ignored."""
        removed = self.metadata.pop(key, None)
        if removed is not None:
            self.filesize -= len(removed.data)
            self._update_metacount()