File size: 759 Bytes
d628c92
 
 
 
 
 
 
5245b11
d628c92
 
 
 
2b7a8a5
d628c92
 
 
 
2b7a8a5
5245b11
 
d628c92
2b7a8a5
 
5245b11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import yaml

# Defaults used by `load_config_values`: the path of the YAML model
# configuration file and the top-level keys that must be present in it.
model_config_dir = "config/model_config.yml"
config_keys = [
    "system_message",
    "model_id",
    "template",
]

def load_config_values(
        model_config_dir=model_config_dir,
        config_keys=config_keys
    ):
    """
    Read a YAML config file and return the values of the requested keys.

    Args:
        model_config_dir: Path of the YAML file to open. Despite the name,
            this is passed straight to `open`, so it must be a file path.
            Defaults to the module-level `model_config_dir`.
        config_keys: Iterable of top-level keys to look up. Defaults to the
            module-level `config_keys`.

    Returns:
        A list holding the value of each key, in the same order as
        `config_keys`.

    Raises:
        ValueError: If the file's top level is not a mapping, or if any
            requested key is absent from it.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(model_config_dir, "r", encoding="utf-8") as file:
        model_config = yaml.safe_load(file)

    # `yaml.safe_load` yields None for an empty document; treat that as an
    # empty mapping so the caller gets the usual "key missing" ValueError
    # rather than an AttributeError.
    if model_config is None:
        model_config = {}
    if not isinstance(model_config, dict):
        raise ValueError(
            f"`{model_config_dir}` must contain a top-level mapping, "
            f"got {type(model_config).__name__}"
        )

    values_list = []
    for var in config_keys:
        # Fail on the first absent key, naming both the key and the file.
        if var not in model_config:
            raise ValueError(f"`{var}` key missing from `{model_config_dir}`")
        values_list.append(model_config[var])

    return values_list