Ngaima Sandiman
Changed transformer version to fix issues.
9cc3964
raw
history blame
903 Bytes
import json
import os
from typing import Optional
from PIL import Image
from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import (
ImageCraftProcessor,
)
def move_inputs_to_device(model_inputs: dict, device: str):
    """Return a new dict with every value of ``model_inputs`` sent to ``device``.

    Each value must provide a ``.to(device)`` method (e.g. a tensor). Keys and
    their order are preserved. The caller's dict is not modified; a fresh dict
    holding the results of ``.to`` is returned instead.
    """
    moved = {}
    for name, value in model_inputs.items():
        moved[name] = value.to(device)
    return moved
def get_model_inputs(
    processor: ImageCraftProcessor,
    prompt: str,
    image: Image.Image,
    device: str = "cuda",
):
    """Preprocess a single prompt/image pair and move the result to ``device``.

    The processor is always called with one-element lists
    (``text=[prompt]``, ``images=[image]``), so the output presumably
    describes a batch of size 1. This depends on ``ImageCraftProcessor``;
    verify against its implementation.

    Args:
        processor: Processor invoked as ``processor(text=..., images=...)``.
            Its return value must support ``.items()``, and every value must
            support ``.to(device)``.
        prompt: Text prompt to encode.
        image: PIL image to encode. The annotation is the ``PIL.Image.Image``
            class; bare ``Image`` would refer to the module.
        device: Target device passed to each value's ``.to``. Defaults to
            ``"cuda"``.

    Returns:
        A new dict with the processor's keys, each value moved to ``device``.
    """
    # Wrap the single inputs in lists because the processor takes batches.
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    model_inputs = move_inputs_to_device(model_inputs, device)
    return model_inputs
def get_config(config_file="config.json"):
    """Load an ``ImageCraftConfig`` from a JSON file.

    Args:
        config_file: Path to a JSON file whose top-level object is passed as
            keyword arguments to ``ImageCraftConfig``. The default,
            ``"config.json"``, is resolved relative to the current working
            directory.

    Returns:
        The constructed ``ImageCraftConfig`` instance.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file does not contain valid JSON.
        TypeError: If the JSON top level is not an object (``**`` needs a
            mapping), or presumably if ``ImageCraftConfig`` rejects a key.
            Verify the latter against ``ImageCraftConfig``.
    """
    # JSON is UTF-8 by specification (RFC 8259). Passing the encoding
    # explicitly avoids decoding with the platform's locale default.
    with open(config_file, "r", encoding="utf-8") as f:
        config_kwargs = json.load(f)
    # Construct the config after the file is closed; it does not need the handle.
    return ImageCraftConfig(**config_kwargs)