Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import os | |
from typing import Optional | |
from PIL import Image | |
from src.model.modules.imagecraftconfig import ImageCraftConfig | |
from src.model.modules.imagecraftprocessor import ( | |
ImageCraftProcessor, | |
) | |
def move_inputs_to_device(model_inputs: dict, device: str): | |
model_inputs = {k: v.to(device) for k, v in model_inputs.items()} | |
return model_inputs | |
def get_model_inputs( | |
processor: ImageCraftProcessor, | |
prompt: str, | |
image: Image, | |
suffix: Optional[str] = None, | |
device: str = "cuda", | |
): | |
images = [image] | |
prompts = [prompt] | |
if suffix is not None: | |
suffix = [suffix] | |
model_inputs = processor(text=prompts, images=images) | |
model_inputs = move_inputs_to_device(model_inputs, device) | |
return model_inputs | |
def get_config(config_file="config.json"): | |
config = None | |
with open(config_file, "r") as f: | |
model_config_file = json.load(f) | |
config = ImageCraftConfig(**model_config_file) | |
return config | |
# def load_hf_model(model_path: str, device: str) -> Tuple[ImageCraft, AutoTokenizer]: | |
# # Load the tokenizer | |
# tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") | |
# assert tokenizer.padding_side == "right" | |
# # Find all the *.safetensors files | |
# safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors")) | |
# # ... and load them one by one in the tensors dictionary | |
# tensors = {} | |
# for safetensors_file in safetensors_files: | |
# with safe_open(safetensors_file, framework="pt", device="cpu") as f: | |
# for key in f.keys(): | |
# tensors[key] = f.get_tensor(key) | |
# # Load the model's config | |
# with open(os.path.join(model_path, "config.json"), "r") as f: | |
# model_config_file = json.load(f) | |
# config = ImageCraftConfig(**model_config_file) | |
# # Create the model using the configuration | |
# model = ImageCraft(config).to(device) | |
# # Load the state dict of the model | |
# model.load_state_dict(tensors, strict=False) | |
# # Tie weights | |
# model.tie_weights() | |
# return (model, tokenizer) | |