"""Convert Flax diffusers checkpoints on the Hugging Face Hub to PyTorch weights.

The converted weights are written in fp32 and fp16, each as both pickled
``.bin`` files and ``safetensors``. They are uploaded back to the source
repository as a pull request.
"""
import argparse
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import List, Optional

import torch
from diffusers import StableDiffusionPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import safety_checker
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name


def _load_flax_model(model_id: str, is_sd: bool):
    """Load *model_id* from its Flax weights as a pipeline (``is_sd``) or a ControlNet model."""
    if is_sd:
        # Load the pipeline without a safety checker component.
        return StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
    return ControlNetModel.from_pretrained(model_id, from_flax=True)


def _save_all_variants(model, folder: str, is_sd: bool) -> None:
    """Write *model* into *folder* as fp32 and fp16, each as ``.bin`` and safetensors."""
    # Full-precision weights, default (pickled) and safetensors serialization.
    model.save_pretrained(folder)
    model.save_pretrained(folder, safe_serialization=True)

    # Cast to half precision before writing the ``fp16`` variant. A pipeline
    # casts all of its components via ``.to``; a bare model uses ``.half()``.
    if is_sd:
        model.to(torch_dtype=torch.float16)
    else:
        model.half()
    model.save_pretrained(folder, variant="fp16")
    model.save_pretrained(folder, safe_serialization=True, variant="fp16")


def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert the Flax weights of *model_id* and open a PR with the PyTorch weights.

    Args:
        api: Hub client used to read the repo file listing and upload the result.
        model_id: Hub repo id of the model to convert.
        force: Currently unused; kept for interface compatibility.

    Returns:
        The value returned by ``api.upload_folder`` for the pull request it
        creates.
    """
    info = api.model_info(model_id)
    filenames = {sibling.rfilename for sibling in info.siblings}
    # A ``model_index.json`` marks a full diffusers pipeline; any other repo
    # is loaded as a single ControlNet model.
    is_sd = "model_index.json" in filenames
    model = _load_flax_model(model_id, is_sd)

    with TemporaryDirectory() as tmp_dir:
        # ``repo_folder_name`` pluralizes ``repo_type`` itself, so pass the
        # singular "model" (passing "models" yields a "modelss--..." folder).
        folder = os.path.join(tmp_dir, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)
        _save_all_variants(model, folder, is_sd)
        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )
        print(model_id)
    return commit_info


def main() -> None:
    """Parse the command line and convert the requested Hub model."""
    description = """
    Simple utility tool to convert Flax diffusers weights on the hub to PyTorch
    (`.bin`) and `safetensors` format, in both fp32 and fp16 variants.
    It works by loading the Flax weights, converting them locally, and
    uploading the result back as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "model_id",
        type=str,
        help=(
            "The name of the model on the hub to convert: a Stable Diffusion "
            "pipeline repo (containing `model_index.json`) or a ControlNet model repo."
        ),
    )
    args = parser.parse_args()
    api = HfApi()
    convert(api, args.model_id)


if __name__ == "__main__":
    main()