tools / convert_flax_to_pt.py
patrickvonplaten's picture
up
40d1ba9
raw
history blame
No virus
2.18 kB
import argparse
import json
import os
import shutil
from diffusers.pipelines.stable_diffusion import safety_checker
import torch
from tempfile import TemporaryDirectory
from typing import List, Optional
from diffusers import StableDiffusionPipeline, ControlNetModel
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Convert a Flax checkpoint on the Hub to PyTorch weights and open a PR with them.

    The repo is loaded with ``from_flax=True`` as a ``StableDiffusionPipeline``
    when it contains a ``model_index.json`` file, and as a ``ControlNetModel``
    otherwise. The loaded model is saved four times into a temporary folder
    (default serialization and ``safe_serialization=True``, each in the loaded
    precision and as an ``fp16`` variant), and that folder is uploaded back to
    the same repo with ``create_pr=True``.

    Args:
        api: Hub client used to inspect the repo and upload the converted files.
        model_id: Repo id of the model to convert.
        force: Currently unused; kept so existing callers keep working.

    Returns:
        Whatever ``api.upload_folder`` returns (expected to be the commit info
        of the created pull request, per the annotation -- depends on the
        installed ``huggingface_hub`` version).
    """
    info = api.model_info(model_id)
    # `siblings` may be missing on the returned info object; treat that as "no files".
    filenames = {sibling.rfilename for sibling in (info.siblings or [])}
    is_sd = "model_index.json" in filenames

    if is_sd:
        model = StableDiffusionPipeline.from_pretrained(model_id, from_flax=True, safety_checker=None)
    else:
        model = ControlNetModel.from_pretrained(model_id, from_flax=True)

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        # Weights in the precision they were loaded in, in both serialization formats.
        model.save_pretrained(folder)
        model.save_pretrained(folder, safe_serialization=True)

        # Cast to half precision; pipelines and single models expose different APIs for it.
        if is_sd:
            model.to(torch_dtype=torch.float16)
        else:
            model.half()

        # Same two formats again, stored under the "fp16" variant name.
        model.save_pretrained(folder, variant="fp16")
        model.save_pretrained(folder, safe_serialization=True, variant="fp16")

        # The upload has to happen while the temporary directory still exists.
        commit_info = api.upload_folder(
            folder_path=folder,
            repo_id=model_id,
            repo_type="model",
            create_pr=True,
        )

    print(model_id)
    # Previously the upload result was dropped and the function always returned None.
    return commit_info
if __name__ == "__main__":
    # Command-line entry point: parse the repo id and run the conversion on it.
    # The description and help text were previously copied from a different
    # conversion tool and did not describe what this script does.
    DESCRIPTION = """
    Simple utility tool to automatically convert Flax weights on the hub to PyTorch.
    It handles Stable Diffusion pipelines and ControlNet models.
    It works by downloading the weights (Flax), converting them locally to PyTorch
    (default serialization and `safetensors`, each also as an `fp16` variant), and
    uploading them back as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help=(
            "The repo id of the Flax Stable Diffusion or ControlNet model on the hub to convert. "
            "E.g. `username/model-name`"
        ),
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id)