File size: 903 Bytes
685ecb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc3964
685ecb2
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
import os
from typing import Optional
from PIL import Image


from src.model.modules.imagecraftconfig import ImageCraftConfig
from src.model.modules.imagecraftprocessor import (
    ImageCraftProcessor,
)


def move_inputs_to_device(model_inputs: dict, device: str):
    """Return a new dict whose values have been moved to *device*.

    Every value must provide a ``.to(device)`` method (presumably torch
    tensors produced by the processor — confirm against callers). The input
    mapping is not modified; a fresh dict with the same keys is returned.
    """
    moved = {}
    for name, value in model_inputs.items():
        moved[name] = value.to(device)
    return moved


def get_model_inputs(
    processor: ImageCraftProcessor,
    prompt: str,
    image: Image.Image,
    device: str = "cuda",
) -> dict:
    """Run *processor* on a single prompt/image pair and move the result to *device*.

    Args:
        processor: Called as ``processor(text=[prompt], images=[image])``.
        prompt: Text prompt; wrapped in a one-element list.
        image: PIL image; wrapped in a one-element list.
        device: Target passed to each output value's ``.to()``; defaults to "cuda".

    Returns:
        A new dict with the processor outputs moved to *device*
        (via ``move_inputs_to_device``).
    """
    # Wrap the single prompt and image in lists; presumably the processor
    # expects batched inputs, so this forms a batch of one (confirm).
    model_inputs = processor(text=[prompt], images=[image])
    return move_inputs_to_device(model_inputs, device)


def get_config(config_file="config.json"):
    """Load an ``ImageCraftConfig`` from a JSON file.

    Args:
        config_file: Path to the JSON file. Defaults to "config.json",
            resolved relative to the current working directory.

    Returns:
        An ``ImageCraftConfig`` built by passing the file's top-level JSON
        object as keyword arguments.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError: If the top-level JSON value is not an object.
    """
    # JSON text is UTF-8 (RFC 8259), so don't rely on the platform's locale
    # encoding. Only keep the file open for as long as parsing takes.
    with open(config_file, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    # ``**`` unpacking needs a mapping. Fail with a clear message (still a
    # TypeError, as before) instead of an opaque unpacking error.
    if not isinstance(model_config, dict):
        raise TypeError(
            f"{config_file!r} must contain a JSON object at the top level, "
            f"got {type(model_config).__name__}"
        )
    return ImageCraftConfig(**model_config)