#!/usr/bin/env python3
#
# This is a Python port of the Rust reference implementation of BLAKE3:
# https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs

from __future__ import annotations

from dataclasses import dataclass

OUT_LEN = 32
KEY_LEN = 32
BLOCK_LEN = 64
CHUNK_LEN = 1024

# Domain-separation flags for the compression function.
CHUNK_START = 1 << 0
CHUNK_END = 1 << 1
PARENT = 1 << 2
ROOT = 1 << 3
KEYED_HASH = 1 << 4
DERIVE_KEY_CONTEXT = 1 << 5
DERIVE_KEY_MATERIAL = 1 << 6

# The initialization vector, shared with SHA-256 (the first 32 bits of the
# fractional parts of the square roots of the first 8 primes).
IV = [
    0x6A09E667,
    0xBB67AE85,
    0x3C6EF372,
    0xA54FF53A,
    0x510E527F,
    0x9B05688C,
    0x1F83D9AB,
    0x5BE0CD19,
]

MSG_PERMUTATION = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]


def mask32(x: int) -> int:
    return x & 0xFFFFFFFF


def add32(x: int, y: int) -> int:
    return mask32(x + y)


def rightrotate32(x: int, n: int) -> int:
    return mask32(x << (32 - n)) | (x >> n)
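
# Sanity checks for the 32-bit helpers, not part of the reference
# implementation (they run once at import time):
assert add32(0xFFFFFFFF, 1) == 0  # addition wraps modulo 2**32
assert rightrotate32(0x00000001, 1) == 0x80000000  # the low bit wraps to the top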


# The mixing function, G, which mixes either a column or a diagonal.
def g(state: list[int], a: int, b: int, c: int, d: int, mx: int, my: int) -> None:
    state[a] = add32(state[a], add32(state[b], mx))
    state[d] = rightrotate32(state[d] ^ state[a], 16)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 12)
    state[a] = add32(state[a], add32(state[b], my))
    state[d] = rightrotate32(state[d] ^ state[a], 8)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 7)
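
# The rotation distances (16, 12, 8, 7) and the add-xor-rotate structure of G
# are inherited from BLAKE2s, which in turn adapted the ChaCha quarter-round.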


def round(state: list[int], m: list[int]) -> None:
    # Mix the columns.
    g(state, 0, 4, 8, 12, m[0], m[1])
    g(state, 1, 5, 9, 13, m[2], m[3])
    g(state, 2, 6, 10, 14, m[4], m[5])
    g(state, 3, 7, 11, 15, m[6], m[7])
    # Mix the diagonals.
    g(state, 0, 5, 10, 15, m[8], m[9])
    g(state, 1, 6, 11, 12, m[10], m[11])
    g(state, 2, 7, 8, 13, m[12], m[13])
    g(state, 3, 4, 9, 14, m[14], m[15])


def permute(m: list[int]) -> None:
    original = list(m)
    for i in range(16):
        m[i] = original[MSG_PERMUTATION[i]]
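
# A quick illustration of `permute`, not part of the reference implementation
# (checked once at import time): permuting the identity schedule
# [0, 1, ..., 15] reproduces MSG_PERMUTATION itself.
_identity = list(range(16))
permute(_identity)
assert _identity == MSG_PERMUTATION
del _identity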


def compress(
    chaining_value: list[int],
    block_words: list[int],
    counter: int,
    block_len: int,
    flags: int,
) -> list[int]:
    state = [
        chaining_value[0],
        chaining_value[1],
        chaining_value[2],
        chaining_value[3],
        chaining_value[4],
        chaining_value[5],
        chaining_value[6],
        chaining_value[7],
        IV[0],
        IV[1],
        IV[2],
        IV[3],
        mask32(counter),  # low 32 bits of the 64-bit counter
        mask32(counter >> 32),  # high 32 bits
        block_len,
        flags,
    ]
    assert len(block_words) == 16
    block = list(block_words)

    round(state, block)  # round 1
    permute(block)
    round(state, block)  # round 2
    permute(block)
    round(state, block)  # round 3
    permute(block)
    round(state, block)  # round 4
    permute(block)
    round(state, block)  # round 5
    permute(block)
    round(state, block)  # round 6
    permute(block)
    round(state, block)  # round 7

    # Feed forward: fold the two halves of the state together, and mix the
    # input chaining value into the second half.
    for i in range(8):
        state[i] ^= state[i + 8]
        state[i + 8] ^= chaining_value[i]
    return state
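
# Of the 16 words compress returns, the first 8 are the new chaining value.
# The second 8 are only consumed when producing extended root output (see
# Output.root_output_bytes below).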


def words_from_little_endian_bytes(b: bytes) -> list[int]:
    assert len(b) % 4 == 0
    return [int.from_bytes(b[i : i + 4], "little") for i in range(0, len(b), 4)]
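
# Word-parsing sanity check, not part of the reference implementation (runs
# once at import time):
assert words_from_little_endian_bytes(
    bytes([1, 0, 0, 0, 0x78, 0x56, 0x34, 0x12])
) == [1, 0x12345678]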


# Each chunk or parent node can produce either an 8-word chaining value or, by
# setting the ROOT flag, any number of final output bytes. The Output struct
# captures the state just prior to choosing between those two possibilities.
@dataclass
class Output:
    input_chaining_value: list[int]
    block_words: list[int]
    counter: int
    block_len: int
    flags: int

    def chaining_value(self) -> list[int]:
        return compress(
            self.input_chaining_value,
            self.block_words,
            self.counter,
            self.block_len,
            self.flags,
        )[:8]
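
    # Extended output (XOF): each compress call below yields 64 fresh output
    # bytes, and the counter argument is the output-block index (i // 64), not
    # a chunk counter. Output of any length therefore costs one compression
    # per 64 bytes.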
    def root_output_bytes(self, length: int) -> bytes:
        output_bytes = bytearray()
        i = 0
        while i < length:
            words = compress(
                self.input_chaining_value,
                self.block_words,
                i // 64,
                self.block_len,
                self.flags | ROOT,
            )
            # The output length might not be a multiple of 4.
            for word in words:
                word_bytes = word.to_bytes(4, "little")
                take = min(len(word_bytes), length - i)
                output_bytes.extend(word_bytes[:take])
                i += take
        return bytes(output_bytes)


class ChunkState:
    chaining_value: list[int]
    chunk_counter: int
    block: bytearray
    block_len: int
    blocks_compressed: int
    flags: int

    def __init__(self, key_words: list[int], chunk_counter: int, flags: int) -> None:
        self.chaining_value = key_words
        self.chunk_counter = chunk_counter
        self.block = bytearray(BLOCK_LEN)
        self.block_len = 0
        self.blocks_compressed = 0
        self.flags = flags

    def len(self) -> int:
        return BLOCK_LEN * self.blocks_compressed + self.block_len

    def start_flag(self) -> int:
        if self.blocks_compressed == 0:
            return CHUNK_START
        else:
            return 0

    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the block buffer is full, compress it and clear it. More
            # input_bytes is coming, so this compression is not CHUNK_END.
            if self.block_len == BLOCK_LEN:
                block_words = words_from_little_endian_bytes(self.block)
                self.chaining_value = compress(
                    self.chaining_value,
                    block_words,
                    self.chunk_counter,
                    BLOCK_LEN,
                    self.flags | self.start_flag(),
                )[:8]
                self.blocks_compressed += 1
                self.block = bytearray(BLOCK_LEN)
                self.block_len = 0

            # Copy input bytes into the block buffer.
            want = BLOCK_LEN - self.block_len
            take = min(want, len(input_bytes))
            self.block[self.block_len : self.block_len + take] = input_bytes[:take]
            self.block_len += take
            input_bytes = input_bytes[take:]

    def output(self) -> Output:
        block_words = words_from_little_endian_bytes(self.block)
        return Output(
            self.chaining_value,
            block_words,
            self.chunk_counter,
            self.block_len,
            self.flags | self.start_flag() | CHUNK_END,
        )
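

# A parent node hashes exactly one 64-byte block: the two child chaining
# values concatenated. Parents always use a counter of 0; only chunks use the
# counter.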
def parent_output(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> Output:
    return Output(
        key_words, left_child_cv + right_child_cv, 0, BLOCK_LEN, PARENT | flags
    )


def parent_cv(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> list[int]:
    return parent_output(
        left_child_cv, right_child_cv, key_words, flags
    ).chaining_value()


# An incremental hasher that can accept any number of writes.
class Hasher:
    chunk_state: ChunkState
    key_words: list[int]
    cv_stack: list[list[int]]
    flags: int

    def _init(self, key_words: list[int], flags: int) -> None:
        assert len(key_words) == 8
        self.chunk_state = ChunkState(key_words, 0, flags)
        self.key_words = key_words
        self.cv_stack = []
        self.flags = flags

    # Construct a new `Hasher` for the regular hash function.
    def __init__(self) -> None:
        self._init(IV, 0)

    # Construct a new `Hasher` for the keyed hash function.
    @classmethod
    def new_keyed(cls, key: bytes) -> Hasher:
        keyed_hasher = cls()
        key_words = words_from_little_endian_bytes(key)
        keyed_hasher._init(key_words, KEYED_HASH)
        return keyed_hasher

    # Construct a new `Hasher` for the key derivation function. The context
    # string should be hardcoded, globally unique, and application-specific.
    @classmethod
    def new_derive_key(cls, context: str) -> Hasher:
        context_hasher = cls()
        context_hasher._init(IV, DERIVE_KEY_CONTEXT)
        context_hasher.update(context.encode("utf8"))
        context_key = context_hasher.finalize(KEY_LEN)
        context_key_words = words_from_little_endian_bytes(context_key)
        derive_key_hasher = cls()
        derive_key_hasher._init(context_key_words, DERIVE_KEY_MATERIAL)
        return derive_key_hasher
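
    # Example usage of the keyed and KDF modes. The key and context values
    # here are illustrative placeholders (the context string follows the
    # format the BLAKE3 documentation recommends), not values from the spec:
    #
    #     keyed = Hasher.new_keyed(b"\x00" * KEY_LEN)
    #     kdf = Hasher.new_derive_key("example.com 2019-12-25 16:18:03 session tokens v1")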

    # Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail.
    def add_chunk_chaining_value(self, new_cv: list[int], total_chunks: int) -> None:
        # This chunk might complete some subtrees. For each completed subtree,
        # its left child will be the current top entry in the CV stack, and
        # its right child will be the current value of `new_cv`. Pop each left
        # child off the stack, merge it with `new_cv`, and overwrite `new_cv`
        # with the result. After all these merges, push the final value of
        # `new_cv` onto the stack. The number of completed subtrees is given
        # by the number of trailing 0-bits in the new total number of chunks.
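        # For example, when the fourth chunk finishes, total_chunks = 0b100
        # has two trailing 0-bits, so two merges happen: first with the CV of
        # the third chunk, then with the parent CV of the first two chunks,
        # leaving one CV on the stack that covers all four.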
        while total_chunks & 1 == 0:
            new_cv = parent_cv(self.cv_stack.pop(), new_cv, self.key_words, self.flags)
            total_chunks >>= 1
        self.cv_stack.append(new_cv)

    # Add input to the hash state. This can be called any number of times.
    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the current chunk is complete, finalize it and reset the
            # chunk state. More input is coming, so this chunk is not ROOT.
            if self.chunk_state.len() == CHUNK_LEN:
                chunk_cv = self.chunk_state.output().chaining_value()
                total_chunks = self.chunk_state.chunk_counter + 1
                self.add_chunk_chaining_value(chunk_cv, total_chunks)
                self.chunk_state = ChunkState(self.key_words, total_chunks, self.flags)

            # Compress input bytes into the current chunk state.
            want = CHUNK_LEN - self.chunk_state.len()
            take = min(want, len(input_bytes))
            self.chunk_state.update(input_bytes[:take])
            input_bytes = input_bytes[take:]

    # Finalize the hash and write any number of output bytes.
    def finalize(self, length: int = OUT_LEN) -> bytes:
        # Starting with the Output from the current chunk, compute all the
        # parent chaining values along the right edge of the tree, until we
        # have the root Output.
        output = self.chunk_state.output()
        parent_nodes_remaining = len(self.cv_stack)
        while parent_nodes_remaining > 0:
            parent_nodes_remaining -= 1
            output = parent_output(
                self.cv_stack[parent_nodes_remaining],
                output.chaining_value(),
                self.key_words,
                self.flags,
            )
        return output.root_output_bytes(length)
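

# A minimal smoke test, not part of the reference implementation. The expected
# value below is the published BLAKE3 hash of the empty input; call
# _self_test() manually to check this port against it.
def _self_test() -> None:
    expected = "af1349b9f5f9a1a6a0404dee36dcc9499bcb25c9adc112b7cc9a93cae41f3262"
    assert Hasher().finalize().hex() == expected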


# If this file is executed directly, hash standard input.
if __name__ == "__main__":
    import sys

    hasher = Hasher()
    while buf := sys.stdin.buffer.read(65536):
        hasher.update(buf)
    print(hasher.finalize().hex())