import os
import subprocess
import sys
import argparse
from pathlib import Path
from concurrent.futures import (
    ProcessPoolExecutor,
    as_completed,
)

from zipnn_compress_file import compress_file

# Make the package root importable when this script is run directly.
sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..")
    )
)

KB = 1024
MB = 1024 * 1024
GB = 1024 * 1024 * 1024

# ANSI color codes for console output.
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"


def check_and_install_zipnn():
    """Import zipnn, installing/upgrading it via pip if it is missing."""
    try:
        import zipnn
    except ImportError:
        print("zipnn not found. Installing...")
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "zipnn",
                "--upgrade",
            ]
        )
        import zipnn


def parse_streaming_chunk_size(streaming_chunk_size):
    """Convert a chunk-size spec into a byte count.

    Accepts either a plain integer (bytes) or a string with a
    two-character unit suffix, e.g. "32KB", "4mb", "1GB".

    Raises:
        ValueError: if the unit suffix is not k/m/g.
    """
    if str(streaming_chunk_size).isdigit():
        return int(streaming_chunk_size)

    # NOTE: expects a two-character suffix ("KB"/"MB"/"GB", any case);
    # only the first letter of the suffix is inspected.
    size_value = int(streaming_chunk_size[:-2])
    size_unit = streaming_chunk_size[-2].lower()
    if size_unit == "k":
        return KB * size_value
    if size_unit == "m":
        return MB * size_value
    if size_unit == "g":
        return GB * size_value
    raise ValueError(
        f"Invalid size unit: {size_unit}. Use 'k', 'm', or 'g'."
    )


def replace_in_file(file_path: Path | str, old: str, new: str) -> None:
    """Given a file_path, replace all occurrences of `old` with `new` inplace."""
    with open(file_path, "r") as file:
        file_data = file.read()

    file_data = file_data.replace(old, new)

    with open(file_path, "w") as file:
        file.write(file_data)


def compress_files_with_suffix(
    suffix,
    dtype="",
    streaming_chunk_size=1048576,
    path=".",
    delete=False,
    r=False,
    force=False,
    max_processes=1,
    hf_cache=False,
    model="",
    branch="main",
):
    """Compress every file under `path` ending with `suffix` using zipnn.

    Files are compressed in parallel (up to `max_processes` workers) via
    `compress_file`, each producing a sibling "<name>.znn" file.

    Args:
        suffix: file suffix to match (e.g. "safetensors").
        dtype: optional dtype hint forwarded to compress_file.
        streaming_chunk_size: chunk size in bytes or "<int><KB|MB|GB>".
        path: directory to search (ignored when `model` is given).
        delete: forwarded to compress_file (delete originals).
        r: recurse into subdirectories.
        force: overwrite existing .znn files without prompting.
        max_processes: max parallel worker processes.
        hf_cache: files live in the Hugging Face cache; fixes up the
            model index json after compression.
        model: HF model id to locate in the local cache (requires hf_cache).
        branch: HF cache ref to resolve when `model` is given.
    """
    import zipnn

    overwrite_first = True
    file_list = []
    streaming_chunk_size = parse_streaming_chunk_size(
        streaming_chunk_size
    )

    if model:
        if not hf_cache:
            raise ValueError(
                "Must specify --hf_cache when using --model"
            )
        try:
            from huggingface_hub import scan_cache_dir
        except ImportError:
            raise ImportError(
                "huggingface_hub not found. Please pip install huggingface_hub."
            )
        cache = scan_cache_dir()
        repo = next(
            (repo for repo in cache.repos if repo.repo_id == model),
            None,
        )
        if repo is not None:
            print(f"Found repo {model} in cache")

            # Resolve the ref to a snapshot hash so we compress the
            # files of the requested branch only.
            revision_hash = ""
            try:
                with open(
                    os.path.join(repo.repo_path, "refs", branch), "r"
                ) as ref:
                    revision_hash = ref.read()
            except FileNotFoundError:
                raise FileNotFoundError(
                    f"Branch {branch} not found in repo {model}"
                )
            path = os.path.join(
                repo.repo_path, "snapshots", revision_hash
            )

    directories_to_search = (
        os.walk(path)
        if r
        else [(path, [], os.listdir(path))]
    )
    files_found = False
    for root, _, files in directories_to_search:
        for file_name in files:
            if file_name.endswith(suffix):
                # FIX: check next to the source file, not the CWD.
                compressed_path = os.path.join(
                    root, file_name + ".znn"
                )
                if not force and os.path.exists(compressed_path):
                    # First hit: offer a blanket overwrite of all files.
                    if overwrite_first:
                        overwrite_first = False
                        user_input = (
                            input(
                                f"Compressed files already exists; Would you like to overwrite them all (y/n)? "
                            )
                            .strip()
                            .lower()
                        )
                        if user_input not in ("y", "yes"):
                            print(f"No forced overwriting.")
                        else:
                            print(f"Overwriting all compressed files.")
                            force = True
                    # Otherwise ask per file.
                    if not force:
                        user_input = (
                            input(
                                f"{compressed_path} already exists; overwrite (y/n)? "
                            )
                            .strip()
                            .lower()
                        )
                        if user_input not in ("y", "yes"):
                            print(f"Skipping {file_name}...")
                            continue
                files_found = True
                full_path = os.path.join(root, file_name)
                file_list.append(full_path)

    if file_list and hf_cache:
        # Patch the HF weight-index json so it points at the .znn files.
        try:
            from transformers.utils import (
                SAFE_WEIGHTS_INDEX_NAME,
                WEIGHTS_INDEX_NAME,
            )
        except ImportError:
            raise ImportError(
                "Transformers not found. Please pip install transformers."
            )
        if os.path.exists(
            os.path.join(path, SAFE_WEIGHTS_INDEX_NAME)
        ):
            print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
            # Index files in the snapshot dir are symlinks into blobs/;
            # edit the underlying blob.
            blob_name = os.path.join(
                path,
                os.readlink(
                    os.path.join(path, SAFE_WEIGHTS_INDEX_NAME)
                ),
            )
            replace_in_file(
                file_path=blob_name,
                old=f"{suffix}",
                new=f"{suffix}.znn",
            )
        elif os.path.exists(
            os.path.join(path, WEIGHTS_INDEX_NAME)
        ):
            print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
            blob_name = os.path.join(
                path,
                os.readlink(
                    os.path.join(path, WEIGHTS_INDEX_NAME)
                ),
            )
            replace_in_file(
                file_path=blob_name,
                old=f"{suffix}",
                new=f"{suffix}.znn",
            )

    # Keep at most max_processes futures in flight; refill as each
    # completes so memory stays bounded for large file lists.
    with ProcessPoolExecutor(
        max_workers=max_processes
    ) as executor:
        future_to_file = {
            executor.submit(
                compress_file,
                file,
                dtype,
                streaming_chunk_size,
                delete,
                True,
                hf_cache,
            ): file
            for file in file_list[:max_processes]
        }
        file_list = file_list[max_processes:]
        while future_to_file:
            for future in as_completed(future_to_file):
                file = future_to_file.pop(future)
                try:
                    future.result()
                except Exception as exc:
                    print(
                        f"{RED}File {file} generated an exception: {exc}{RESET}"
                    )
                if file_list:
                    next_file = file_list.pop(0)
                    future_to_file[
                        executor.submit(
                            compress_file,
                            next_file,
                            dtype,
                            streaming_chunk_size,
                            delete,
                            True,
                            hf_cache,
                        )
                    ] = next_file

    if not files_found:
        print(
            f"{RED}No files with the suffix '{suffix}' found.{RESET}"
        )
    print(f"{GREEN}All files compressed{RESET}")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python compress_files.py ")
        print("Example: python compress_files.py 'safetensors'")
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description="Enter a suffix to compress, (optional) dtype, (optional) streaming chunk size, (optional) path to files."
    )
    parser.add_argument(
        "suffix",
        type=str,
        help="Specify the file suffix to compress all files with that suffix. If a single file name is provided, only that file will be compressed.",
    )
    parser.add_argument(
        "--float32",
        action="store_true",
        help="A flag that triggers float32 compression",
    )
    parser.add_argument(
        "--streaming_chunk_size",
        type=str,
        help="An optional streaming chunk size. The format is int (for size in Bytes) or int+KB/MB/GB. Default is 1MB",
    )
    parser.add_argument(
        "--path",
        type=str,
        help="Path to files to compress",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="A flag that triggers deletion of a single file instead of compression",
    )
    parser.add_argument(
        "-r",
        action="store_true",
        help="A flag that triggers recursive search on all subdirectories",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="A flag that triggers recursive search on all subdirectories",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="A flag that forces overwriting when compressing.",
    )
    parser.add_argument(
        "--max_processes",
        type=int,
        help="The amount of maximum processes.",
    )
    parser.add_argument(
        "--hf_cache",
        action="store_true",
        help="A flag that indicates if the file is in the Hugging Face cache. Must either specify --model or --path to the model's snapshot cache.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Only when using --hf_cache, specify the model name or path. E.g. 'ibm-granite/granite-7b-instruct'",
    )
    parser.add_argument(
        "--model_branch",
        type=str,
        default="main",
        help="Only when using --model, specify the model branch. Default is 'main'",
    )
    args = parser.parse_args()

    optional_kwargs = {}
    if args.float32:
        optional_kwargs["dtype"] = 32
    if args.streaming_chunk_size is not None:
        optional_kwargs[
            "streaming_chunk_size"
        ] = args.streaming_chunk_size
    if args.path is not None:
        optional_kwargs["path"] = args.path
    if args.delete:
        optional_kwargs["delete"] = args.delete
    # FIX: honor --recursive as well as -r (previously only args.r was
    # forwarded, so --recursive alone was silently ignored).
    if args.r or args.recursive:
        optional_kwargs["r"] = args.r or args.recursive
    if args.force:
        optional_kwargs["force"] = args.force
    if args.max_processes:
        optional_kwargs["max_processes"] = args.max_processes
    if args.hf_cache:
        optional_kwargs["hf_cache"] = args.hf_cache
    if args.model:
        optional_kwargs["model"] = args.model
    if args.model_branch:
        optional_kwargs["branch"] = args.model_branch

    check_and_install_zipnn()
    compress_files_with_suffix(
        args.suffix, **optional_kwargs
    )