Ngaima Sandiman
Initial commit.
685ecb2
raw
history blame
2.12 kB
import json
import os
from typing import Optional
from PIL import Image
from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import (
ImageCraftProcessor,
)
def move_inputs_to_device(model_inputs: dict, device: str):
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
return model_inputs
def get_model_inputs(
processor: ImageCraftProcessor,
prompt: str,
image: Image,
suffix: Optional[str] = None,
device: str = "cuda",
):
images = [image]
prompts = [prompt]
if suffix is not None:
suffix = [suffix]
model_inputs = processor(text=prompts, images=images)
model_inputs = move_inputs_to_device(model_inputs, device)
return model_inputs
def get_config(config_file="config.json"):
config = None
with open(config_file, "r") as f:
model_config_file = json.load(f)
config = ImageCraftConfig(**model_config_file)
return config
# def load_hf_model(model_path: str, device: str) -> Tuple[ImageCraft, AutoTokenizer]:
# # Load the tokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
# assert tokenizer.padding_side == "right"
# # Find all the *.safetensors files
# safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
# # ... and load them one by one in the tensors dictionary
# tensors = {}
# for safetensors_file in safetensors_files:
# with safe_open(safetensors_file, framework="pt", device="cpu") as f:
# for key in f.keys():
# tensors[key] = f.get_tensor(key)
# # Load the model's config
# with open(os.path.join(model_path, "config.json"), "r") as f:
# model_config_file = json.load(f)
# config = ImageCraftConfig(**model_config_file)
# # Create the model using the configuration
# model = ImageCraft(config).to(device)
# # Load the state dict of the model
# model.load_state_dict(tensors, strict=False)
# # Tie weights
# model.tie_weights()
# return (model, tokenizer)