import struct

from enum import IntEnum
from fsspec.spec import AbstractBufferedFile
from typing import Any, Iterator, NamedTuple


class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12


standard_metadata = {
    "general.type": (GGUFValueType.STRING, "model"),
    "general.architecture": (GGUFValueType.STRING, "llama"),
    "general.quantization_version": (GGUFValueType.UINT32, 2),
    "general.alignment": (GGUFValueType.UINT32, 32),
    "general.file_type": (GGUFValueType.UINT32, 0),
    "general.name": (GGUFValueType.STRING, ""),
    "general.author": (GGUFValueType.STRING, ""),
    "general.version": (GGUFValueType.STRING, ""),
    "general.organization": (GGUFValueType.STRING, ""),
    "general.finetune": (GGUFValueType.STRING, ""),
    "general.basename": (GGUFValueType.STRING, ""),
    "general.description": (GGUFValueType.STRING, ""),
    "general.quantized_by": (GGUFValueType.STRING, ""),
    "general.size_label": (GGUFValueType.STRING, ""),
    "general.license": (GGUFValueType.STRING, ""),
    "general.license.name": (GGUFValueType.STRING, ""),
    "general.license.link": (GGUFValueType.STRING, ""),
    "general.url": (GGUFValueType.STRING, ""),
    "general.doi": (GGUFValueType.STRING, ""),
    "general.uuid": (GGUFValueType.STRING, ""),
    "general.repo_url": (GGUFValueType.STRING, ""),
    "general.source.url": (GGUFValueType.STRING, ""),
    "general.source.doi": (GGUFValueType.STRING, ""),
    "general.source.uuid": (GGUFValueType.STRING, ""),
    "general.source.repo_url": (GGUFValueType.STRING, ""),
    "general.tags": (GGUFValueType.STRING, []),
    "general.languages": (GGUFValueType.STRING, []),
    "general.datasets": (GGUFValueType.STRING, []),
    "split.no": (GGUFValueType.UINT16, 0),
    "split.count": (GGUFValueType.UINT16, 0),
    "split.tensors.count": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.model": (GGUFValueType.STRING, "gpt2"),
    "tokenizer.ggml.pre": (GGUFValueType.STRING, "llama-bpe"),
    "tokenizer.ggml.tokens": (GGUFValueType.STRING, []),
    "tokenizer.ggml.token_type": (GGUFValueType.INT32, []),
    "tokenizer.ggml.scores": (GGUFValueType.FLOAT32, []),
    "tokenizer.ggml.merges": (GGUFValueType.STRING, []),
    "tokenizer.ggml.bos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.unknown_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.seperator_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.padding_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.cls_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.mask_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.add_bos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_eos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_space_prefix": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.remove_extra_whitespaces": (GGUFValueType.BOOL, False),
    "tokenizer.chat_template": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.rag": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.tool_use": (GGUFValueType.STRING, ""),
    "tokenizer.chat_templates": (GGUFValueType.STRING, []),
    "tokenizer.ggml.prefix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.suffix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.middle_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eot_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eom_token_id": (GGUFValueType.UINT32, 0),
    "quantize.imatrix.file": (GGUFValueType.STRING, ""),
    "quantize.imatrix.dataset": (GGUFValueType.STRING, ""),
    "quantize.imatrix.entries_count": (GGUFValueType.INT32, 0),
"quantize.imatrix.chunks_count": (GGUFValueType.INT32, 0), } gguf_scalar_size: dict[GGUFValueType, int] = { GGUFValueType.UINT8: 1, GGUFValueType.INT8: 1, GGUFValueType.UINT16: 2, GGUFValueType.INT16: 2, GGUFValueType.UINT32: 4, GGUFValueType.INT32: 4, GGUFValueType.FLOAT32: 4, GGUFValueType.BOOL: 1, GGUFValueType.UINT64: 8, GGUFValueType.INT64: 8, GGUFValueType.FLOAT64: 8, } gguf_scalar_pack: dict[GGUFValueType, str] = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", GGUFValueType.UINT16: "H", GGUFValueType.INT16: "h", GGUFValueType.UINT32: "I", GGUFValueType.INT32: "i", GGUFValueType.FLOAT32: "f", GGUFValueType.BOOL: "?", GGUFValueType.UINT64: "Q", GGUFValueType.INT64: "q", GGUFValueType.FLOAT64: "d", } class GGUFData(NamedTuple): type: GGUFValueType | None value: Any data: bytes class HuggingGGUFstream: fp: AbstractBufferedFile header: dict[str, GGUFData] metadata: dict[str, GGUFData] endian: str metaend: int filesize: int def __init__( self, fp: AbstractBufferedFile, ) -> None: self.fp = fp self.header = {} self.metadata = {} self.endian = '<' self.metaend = 0 self.filesize = fp.details.get('size') if (data := self.fp.read(4)) != b'GGUF': raise TypeError('File is not a GGUF') self.header['magic'] = GGUFData( type = None, value = None, data = data, ) data = self._read_field(GGUFValueType.UINT32) if data.value != 3: if data.value == 3 << 24: data.value = 3 self.endian = '>' else: raise TypeError('Unsupported GGUF version') self.header['version'] = data data = self._read_field(GGUFValueType.UINT64) self.header['tensors'] = data data = self._read_field(GGUFValueType.UINT64) self.header['metadata'] = data def _unpack_field( self, buffer: bytes, field_type: GGUFValueType, repeat: int = 1, ) -> Any: value = struct.unpack(f'{self.endian}{repeat}{gguf_scalar_pack.get(field_type)}', buffer) return value[0] if repeat == 1 else value def _pack_field( self, field_type: GGUFValueType, *values, ) -> bytes: return struct.pack(f'{self.endian}{len(values)}{gguf_scalar_pack.get(field_type)}', *values) def _pack_value( self, val_type: GGUFValueType, value: Any, ) -> bytes: if isinstance(value, list): data = self._pack_field(GGUFValueType.UINT32, val_type) data += self._pack_field(GGUFValueType.UINT64, len(value)) if val_type == GGUFValueType.ARRAY: raise TypeError('Array of arrays currently unsupported') elif val_type == GGUFValueType.STRING: if isinstance(value, list): for v in value: buf = str(v).encode('utf-8') data += self._pack_field(GGUFValueType.UINT64, len(buf)) data += buf else: buf = str(value).encode('utf-8') data = self._pack_field(GGUFValueType.UINT64, len(buf)) data += buf elif val_type in gguf_scalar_pack: if isinstance(value, list): data += self._pack_field(val_type, *value) else: data = self._pack_field(val_type, value) else: raise TypeError('Unknown metadata type') return data def _read_field( self, field_type: GGUFValueType, repeat: int = 1, ) -> GGUFData: data = self.fp.read(gguf_scalar_size.get(field_type) * repeat) value = self._unpack_field(data, field_type, repeat = repeat) return GGUFData( type = field_type, value = value, data = data, ) def _read_value( self, val_type: GGUFValueType, ) -> GGUFData: if val_type == GGUFValueType.ARRAY: data = self._read_field(GGUFValueType.UINT32) val_len = self._read_field(GGUFValueType.UINT64) if data.value in gguf_scalar_pack: val = self._read_field(data.value, repeat = val_len.value) data = GGUFData( type = val.type, value = list(val.value), data = data.data + val_len.data + val.data, ) else: v = [] d = [data.data, val_len.data] for _ in 
range(val_len.value): val = self._read_value(data.value) d.append(val.data) v.append(val.value) data = GGUFData( type = data.value, value = v, data = b''.join(d), ) elif val_type == GGUFValueType.STRING: data = self._read_field(GGUFValueType.UINT64) val = self.fp.read(data.value) data = GGUFData( type = val_type, value = val.decode('utf-8'), data = data.data + val, ) elif val_type in gguf_scalar_pack: data = self._read_field(val_type) else: raise TypeError('Unknown metadata type') return data def _update_metacount( self, ) -> None: old_count = self.header['metadata'] new_count = len(self.metadata) self.header['metadata'] = GGUFData( type = old_count.type, value = new_count, data = self._pack_field(old_count.type, new_count), ) def read_metadata( self, ) -> Iterator[tuple[str, GGUFData]]: if self.metadata: for k, v in self.metadata.items(): yield k, v else: num_metadata = self.header['metadata'].value for _ in range(num_metadata): key = self._read_value(GGUFValueType.STRING) val_type = self._read_field(GGUFValueType.UINT32) val = self._read_value(val_type.value) self.metadata[key.value] = val = GGUFData( type = val.type, value = val.value, data = key.data + val_type.data + val.data, ) yield key.value, val self.metaend = self.fp.loc def add_metadata( self, key: str, type: GGUFValueType, value: Any, ) -> None: data = self._pack_value(GGUFValueType.STRING, key) data += self._pack_field(GGUFValueType.UINT32, GGUFValueType.ARRAY if isinstance(value, list) else type) data += self._pack_value(type, value) if (meta := self.metadata.get(key)): self.filesize -= len(meta.data) self.filesize += len(data) self.metadata[key] = GGUFData( type = type, value = value, data = data, ) if not meta: self._update_metacount() def remove_metadata( self, key: str, ) -> None: if (meta := self.metadata.get(key)): del self.metadata[key] self.filesize -= len(meta.data) self._update_metacount()
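

# Minimal usage sketch: stream GGUF metadata directly from the Hugging Face Hub.
# This assumes `huggingface_hub` is installed (its HfFileSystem yields fsspec
# AbstractBufferedFile objects); the repository and file name below are
# placeholders, not a real model.
if __name__ == '__main__':
    from huggingface_hub import HfFileSystem

    fs = HfFileSystem()

    # Hypothetical path: any '<user>/<repo>/<file>.gguf' hosted on the Hub would do.
    with fs.open('some-user/some-model-GGUF/model-Q4_K_M.gguf', 'rb') as fp:
        gguf = HuggingGGUFstream(fp)

        print('version:', gguf.header['version'].value)
        print('tensors:', gguf.header['tensors'].value)

        # read_metadata() parses key/value pairs lazily as they are pulled from the stream.
        for key, data in gguf.read_metadata():
            print(key, '=', data.value)

        # Editing only touches the in-memory maps; writing the result back out is
        # left to the caller (e.g. by splicing the re-packed metadata with the
        # remainder of the file starting at gguf.metaend).
        gguf.add_metadata('general.name', GGUFValueType.STRING, 'renamed-model')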