File size: 903 Bytes
685ecb2 9cc3964 685ecb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import json
import os
from typing import Optional
from PIL import Image
from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import (
ImageCraftProcessor,
)
def move_inputs_to_device(model_inputs: dict, device: str):
    """Return a new dict with every value of ``model_inputs`` moved to ``device``.

    Each value must expose a ``.to(device)`` method; the result of that call
    is stored under the same key. The input mapping itself is not modified.

    Args:
        model_inputs: Mapping of input names to objects with a ``.to()`` method.
        device: Target passed verbatim to each value's ``.to()``.

    Returns:
        A freshly built dict with the same keys, in the same order.
    """
    moved = {}
    for name, value in model_inputs.items():
        moved[name] = value.to(device)
    return moved
def get_model_inputs(
    processor: ImageCraftProcessor,
    prompt: str,
    image: Image.Image,
    device: str = "cuda",
):
    """Build processor inputs for one prompt/image pair and move them to ``device``.

    Args:
        processor: Callable invoked as ``processor(text=[prompt], images=[image])``.
        prompt: Text prompt; passed to the processor as a single-element list.
        image: PIL image; passed to the processor as a single-element list.
            Annotated as ``Image.Image`` because ``Image`` alone is the
            ``PIL.Image`` module, not a type.
        device: Target handed to ``move_inputs_to_device``; defaults to ``"cuda"``.

    Returns:
        The dict produced by ``move_inputs_to_device`` for the processor output.
    """
    # The processor is called with lists, so wrap the single prompt and image.
    model_inputs = processor(text=[prompt], images=[image])
    return move_inputs_to_device(model_inputs, device)
def get_config(config_file: str = "config.json") -> "ImageCraftConfig":
    """Load an ``ImageCraftConfig`` from a JSON file.

    Args:
        config_file: Path to a JSON file whose top-level object supplies the
            keyword arguments for ``ImageCraftConfig``.

    Returns:
        The config built as ``ImageCraftConfig(**loaded_json)``.

    Raises:
        OSError: If ``config_file`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin UTF-8 so decoding does not depend on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as f:
        model_config_file = json.load(f)
    return ImageCraftConfig(**model_config_file)
|