File size: 1,850 Bytes
ce13d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from typing import List, Optional
# Prefix that marks a prompt entry as a path to a prompt file rather than free text.
PROMPTFILE_PREFIX = 'file::'

def load_prompts(prompts: List[str], prompt_delimiter: Optional[str]=None) -> List[str]:
    """Load a set of prompts, both free text and from file.

    Entries that start with ``PROMPTFILE_PREFIX`` are read through
    ``load_prompts_from_file``; every other entry is used verbatim. The
    output preserves the order of the input entries.

    Args:
        prompts (List[str]): List of free text prompts and prompt files
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s)

    Raises:
        FileNotFoundError: If a prompt file entry does not refer to an existing file.
    """
    prompt_strings = []
    for prompt in prompts:
        if prompt.startswith(PROMPTFILE_PREFIX):
            # Extend directly rather than rebinding ``prompts``, which is the
            # argument currently being iterated over.
            prompt_strings.extend(load_prompts_from_file(prompt, prompt_delimiter))
        else:
            prompt_strings.append(prompt)
    return prompt_strings

def load_prompts_from_file(prompt_path: str, prompt_delimiter: Optional[str]=None) -> List[str]:
    """Load a set of prompts from a text file.

    Args:
        prompt_path (str): Path for text file, prefixed with ``PROMPTFILE_PREFIX``.
            A leading ``~`` in the path is expanded to the user's home directory.
        prompt_delimiter (Optional str): Delimiter for text file
            If not provided, assumes the prompt file is a single prompt (non-delimited)

    Returns:
        List of prompt string(s). When a delimiter is given, empty segments
        produced by the split are dropped.

    Raises:
        ValueError: If ``prompt_path`` does not start with ``PROMPTFILE_PREFIX``,
            or if ``prompt_delimiter`` is an empty string (``str.split`` rejects
            an empty separator).
        FileNotFoundError: If the path does not refer to an existing file.
    """
    if not prompt_path.startswith(PROMPTFILE_PREFIX):
        raise ValueError(f'prompt_path must start with {PROMPTFILE_PREFIX}')
    # The prefix is known to be present, so slice it off the front.
    prompt_file_path = os.path.expanduser(prompt_path[len(PROMPTFILE_PREFIX):])
    if not os.path.isfile(prompt_file_path):
        raise FileNotFoundError(f'prompt_file_path={prompt_file_path!r} does not match any existing files.')
    # Decode explicitly as UTF-8 so the result does not depend on the platform locale.
    with open(prompt_file_path, 'r', encoding='utf-8') as f:
        prompt_string = f.read()
    if prompt_delimiter is None:
        return [prompt_string]
    return [segment for segment in prompt_string.split(prompt_delimiter) if segment]