import json
import os
from typing import Optional

from PIL import Image

from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import (
    ImageCraftProcessor,
)


def move_inputs_to_device(model_inputs: dict, device: str) -> dict:
    """Return a new dict with every value of *model_inputs* moved to *device*.

    Each value is moved by calling its ``.to(device)`` method, so every value
    must provide one (presumably these are torch tensors -- confirm against
    the processor's output). The input mapping itself is not mutated.

    Args:
        model_inputs: Mapping of input name to a value exposing ``.to()``.
        device: Device identifier forwarded unchanged to each ``.to()`` call.

    Returns:
        A new ``dict`` with the same keys and the moved values.
    """
    return {name: value.to(device) for name, value in model_inputs.items()}


def get_model_inputs(
    processor: ImageCraftProcessor,
    prompt: str,
    image: Image.Image,
    device: str = "cuda",
) -> dict:
    """Run *processor* on a single prompt/image pair and move the result to *device*.

    The processor is called with one-element lists, i.e. a batch of size one.

    Args:
        processor: Callable accepting ``text=`` and ``images=`` keyword
            arguments and returning a mapping of model inputs.
        prompt: Text prompt for the image.
        image: The image to process.
        device: Target device for the processed inputs (defaults to ``"cuda"``).

    Returns:
        A ``dict`` of the processor's outputs, each moved to *device*.
    """
    # The processor is given lists, so wrap the single sample accordingly.
    model_inputs = processor(text=[prompt], images=[image])
    return move_inputs_to_device(model_inputs, device)


def get_config(config_file="config.json") -> ImageCraftConfig:
    """Build an ``ImageCraftConfig`` from a JSON file.

    The top-level JSON object is unpacked as keyword arguments to
    ``ImageCraftConfig``.

    Args:
        config_file: Path to the JSON configuration file.

    Returns:
        The configuration object constructed from the file's contents.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file does not contain valid JSON.
        TypeError: If the top-level JSON value is not an object, since it
            cannot be unpacked as keyword arguments.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_file, "r", encoding="utf-8") as f:
        model_config_file = json.load(f)
    return ImageCraftConfig(**model_config_file)