File size: 7,171 Bytes
ce13d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Custom exceptions for the LLMFoundry."""
from collections.abc import Mapping
from typing import Any, Dict, List

class MissingHuggingFaceURLSplitError(ValueError):
    """Error thrown when there's no split used in HF dataset config.

    The constructor takes no arguments and always produces the same message.
    """

    def __init__(self) -> None:
        message = (
            'When using a HuggingFace dataset from a URL, you must set the '
            '`split` key in the dataset config.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild through the zero-argument constructor on pickle/copy.

        The default ``BaseException.__reduce__`` replays ``self.args`` (the
        rendered message) into ``__init__``, which accepts no arguments, so
        unpickling would raise ``TypeError``.
        """
        return (type(self), ())

class NotEnoughDatasetSamplesError(ValueError):
    """Error thrown when there is not enough data to train a model.

    Every constructor argument is stored on the instance under the same name
    and interpolated into the error message.

    Args:
        dataset_name: Name of the dataset, as reported in the message.
        split: Dataset split, as reported in the message.
        dataloader_batch_size: Reported as the per device batch size.
        world_size: Reported as the number of gpus being run on.
        full_dataset_size: Reported as the number of samples in the dataset.
        minimum_dataset_size: Reported as the minimum number of samples the
            dataset must contain.
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        dataloader_batch_size: int,
        world_size: int,
        full_dataset_size: int,
        minimum_dataset_size: int,
    ) -> None:
        self.dataset_name = dataset_name
        self.split = split
        self.dataloader_batch_size = dataloader_batch_size
        self.world_size = world_size
        self.full_dataset_size = full_dataset_size
        self.minimum_dataset_size = minimum_dataset_size
        message = (
            f'Your dataset (name={dataset_name}, split={split}) '
            f'has {full_dataset_size} samples, but your minimum batch size '
            f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
            f'your per device batch size is {dataloader_batch_size}. Please increase the number '
            f'of samples in your dataset to at least {minimum_dataset_size}.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The default ``BaseException.__reduce__`` replays ``self.args`` (only
        the rendered message) into ``__init__``, which needs six arguments, so
        unpickling would raise ``TypeError``.
        """
        init_args = (
            self.dataset_name,
            self.split,
            self.dataloader_batch_size,
            self.world_size,
            self.full_dataset_size,
            self.minimum_dataset_size,
        )
        return (type(self), init_args)

class UnknownExampleTypeError(KeyError):
    """Error thrown when an unknown example type is used in a task.

    Args:
        example: The offending example. Its ``repr`` is embedded in the
            message and the object is stored as ``self.example``.
    """

    def __init__(self, example: Mapping) -> None:
        # Keep the example on the instance, like the sibling errors do with
        # their arguments, so handlers can inspect it without parsing text.
        self.example = example
        message = f'Unknown example type example={example!r}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original example on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``example``, nesting the message inside itself.
        """
        return (type(self), (self.example,))

class TooManyKeysInExampleError(ValueError):
    """Error thrown when a data sample has too many keys.

    Args:
        desired_keys: The set reported in the message as ``allowed_keys``.
            Stored as ``self.desired_keys``.
        keys: The keys provided by the sample; both its size and contents are
            reported in the message. Stored as ``self.keys``.
    """

    def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
        # Store the arguments, consistent with the sibling error classes.
        self.desired_keys = desired_keys
        self.keys = keys
        message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The default ``BaseException.__reduce__`` replays only the rendered
        message, which does not match the two-argument constructor and would
        raise ``TypeError`` on unpickling.
        """
        return (type(self), (self.desired_keys, self.keys))

class NotEnoughChatDataError(ValueError):
    """Error thrown when there is not enough chat data to train a model.

    The constructor takes no arguments and always produces the same message.
    """

    def __init__(self) -> None:
        super().__init__('Chat example must have at least two messages')

    def __reduce__(self) -> tuple:
        """Rebuild through the zero-argument constructor on pickle/copy.

        The default ``BaseException.__reduce__`` replays ``self.args`` (the
        message) into ``__init__``, which accepts no arguments, so unpickling
        would raise ``TypeError``.
        """
        return (type(self), ())

class ConsecutiveRepeatedChatRolesError(ValueError):
    """Error thrown when there are consecutive repeated chat roles.

    Args:
        repeated_role: The role named in the message as repeated. Stored as
            ``self.repeated_role``.
    """

    def __init__(self, repeated_role: str) -> None:
        self.repeated_role = repeated_role
        message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original role on pickle/copy.

        The default ``BaseException.__reduce__`` would feed the rendered
        message back in as ``repeated_role``, corrupting both the attribute
        and the message of the restored exception.
        """
        return (type(self), (self.repeated_role,))

class InvalidLastChatMessageRoleError(ValueError):
    """Error thrown when the last message role in a chat example is invalid.

    Args:
        last_role: The role reported in the message as invalid. Stored as
            ``self.last_role``.
        expected_roles: The roles reported in the message as expected. Stored
            as ``self.expected_roles``.
    """

    def __init__(self, last_role: str, expected_roles: set[str]) -> None:
        # Store the arguments, consistent with the sibling error classes.
        self.last_role = last_role
        self.expected_roles = expected_roles
        message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The default ``BaseException.__reduce__`` replays only the rendered
        message, which does not match the two-argument constructor and would
        raise ``TypeError`` on unpickling.
        """
        return (type(self), (self.last_role, self.expected_roles))

class IncorrectMessageKeyQuantityError(ValueError):
    """Error thrown when a message has an incorrect number of keys.

    Args:
        keys: The keys found in the message; only their count appears in the
            error text. Stored as ``self.keys``.
    """

    def __init__(self, keys: List[str]) -> None:
        self.keys = keys
        message = f'Expected 2 keys in message, but found {len(keys)}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original keys on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``keys``, so the restored exception would report
        the length of the message string instead of the number of keys.
        """
        return (type(self), (self.keys,))

class InvalidRoleError(ValueError):
    """Error thrown when a role is invalid.

    Args:
        role: The role that was found. Stored as ``self.role``.
        valid_roles: The roles reported in the message as acceptable. Stored
            as ``self.valid_roles``.
    """

    def __init__(self, role: str, valid_roles: set[str]) -> None:
        self.role = role
        self.valid_roles = valid_roles
        message = f'Expected role to be one of {valid_roles} but found: {role}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The default ``BaseException.__reduce__`` replays only the rendered
        message, which does not match the two-argument constructor and would
        raise ``TypeError`` on unpickling.
        """
        return (type(self), (self.role, self.valid_roles))

class InvalidContentTypeError(TypeError):
    """Error thrown when the content type is invalid.

    Args:
        content_type: The type that was found instead of a string. Stored as
            ``self.content_type``.
    """

    def __init__(self, content_type: type) -> None:
        self.content_type = content_type
        message = f'Expected content to be a string, but found {content_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original type on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``content_type``, corrupting both the attribute
        and the message of the restored exception.
        """
        return (type(self), (self.content_type,))

class InvalidPromptTypeError(TypeError):
    """Error thrown when the prompt type is invalid.

    Args:
        prompt_type: The type that was found instead of a string. Stored as
            ``self.prompt_type``.
    """

    def __init__(self, prompt_type: type) -> None:
        self.prompt_type = prompt_type
        message = f'Expected prompt to be a string, but found {prompt_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original type on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``prompt_type``, corrupting both the attribute and
        the message of the restored exception.
        """
        return (type(self), (self.prompt_type,))

class InvalidResponseTypeError(TypeError):
    """Error thrown when the response type is invalid.

    Args:
        response_type: The type that was found instead of a string. Stored as
            ``self.response_type``.
    """

    def __init__(self, response_type: type) -> None:
        self.response_type = response_type
        message = f'Expected response to be a string, but found {response_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original type on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``response_type``, corrupting both the attribute
        and the message of the restored exception.
        """
        return (type(self), (self.response_type,))

class InvalidPromptResponseKeysError(ValueError):
    """Error thrown when missing expected prompt and response keys.

    Args:
        mapping: The mapping whose ``repr`` is embedded in the message.
            Stored as ``self.mapping``.
        example: The example being processed. It does not appear in the
            message and is only stored as ``self.example``.
    """

    def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]) -> None:
        # Store both arguments; previously only ``example`` was kept, which
        # left the mapping recoverable only by parsing the message text.
        self.mapping = mapping
        self.example = example
        message = f'Expected mapping={mapping!r} to have keys "prompt" and "response".'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The default ``BaseException.__reduce__`` replays only the rendered
        message, which does not match the two-argument constructor and would
        raise ``TypeError`` on unpickling.
        """
        return (type(self), (self.mapping, self.example))

class InvalidFileExtensionError(FileNotFoundError):
    """Error thrown when a file extension is not a safe extension.

    Args:
        dataset_name: Reported in the message as the local path of the
            dataset. Stored as ``self.dataset_name``.
        valid_extensions: The extensions reported in the message as safe.
            Stored as ``self.valid_extensions``.
    """

    def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
        self.dataset_name = dataset_name
        self.valid_extensions = valid_extensions
        message = (
            f'safe_load is set to True. No data files with safe extensions {valid_extensions} '
            f'found for dataset at local path {dataset_name}.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original constructor arguments on pickle/copy.

        The inherited ``__reduce__`` replays only the rendered message, which
        does not match the two-argument constructor and would raise
        ``TypeError`` on unpickling.
        """
        return (type(self), (self.dataset_name, self.valid_extensions))

class UnableToProcessPromptResponseError(ValueError):
    """Error thrown when a prompt and response cannot be processed.

    Args:
        input: The object the prompt/response could not be extracted from.
            It is interpolated into the message and stored as ``self.input``.
    """

    def __init__(self, input: Dict) -> None:
        # The parameter name shadows the builtin but is part of the public
        # signature, so it is kept. Store it like the sibling errors do.
        self.input = input
        message = f'Unable to extract prompt/response from {input}'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original input on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``input``, nesting the message inside itself.
        """
        return (type(self), (self.input,))

class ClusterDoesNotExistError(ValueError):
    """Error thrown when the cluster does not exist.

    Args:
        cluster_id: The id reported in the message as not existing. Stored as
            ``self.cluster_id``.
    """

    def __init__(self, cluster_id: str) -> None:
        self.cluster_id = cluster_id
        message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original cluster id on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``cluster_id``, corrupting both the attribute and
        the message of the restored exception.
        """
        return (type(self), (self.cluster_id,))

class FailedToCreateSQLConnectionError(RuntimeError):
    """Error thrown when client can't sql connect to Databricks.

    The constructor takes no arguments and always produces the same message.
    """

    def __init__(self) -> None:
        super().__init__(
            'Failed to create sql connection to db workspace. '
            'To use sql connect, you need to provide http_path and cluster_id!',
        )

    def __reduce__(self) -> tuple:
        """Rebuild through the zero-argument constructor on pickle/copy.

        The default ``BaseException.__reduce__`` replays ``self.args`` (the
        message) into ``__init__``, which accepts no arguments, so unpickling
        would raise ``TypeError``.
        """
        return (type(self), ())

class FailedToConnectToDatabricksError(RuntimeError):
    """Error thrown when the client fails to connect to Databricks.

    The constructor takes no arguments and always produces the same message.
    """

    def __init__(self) -> None:
        super().__init__(
            'Failed to create databricks connection. Check hostname and access token!',
        )

    def __reduce__(self) -> tuple:
        """Rebuild through the zero-argument constructor on pickle/copy.

        The default ``BaseException.__reduce__`` replays ``self.args`` (the
        message) into ``__init__``, which accepts no arguments, so unpickling
        would raise ``TypeError``.
        """
        return (type(self), ())

class InputFolderMissingDataError(ValueError):
    """Error thrown when the input folder is missing data.

    Args:
        input_folder: The location reported in the message as containing no
            text files. Stored as ``self.input_folder``.
    """

    def __init__(self, input_folder: str) -> None:
        self.input_folder = input_folder
        message = f'No text files were found at {input_folder}.'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original folder on pickle/copy.

        The default ``BaseException.__reduce__`` would pass the rendered
        message back in as ``input_folder``, corrupting both the attribute
        and the message of the restored exception.
        """
        return (type(self), (self.input_folder,))

class OutputFolderNotEmptyError(FileExistsError):
    """Error thrown when the output folder is not empty.

    Args:
        output_folder: The folder reported in the message as not empty.
            Stored as ``self.output_folder``.
    """

    def __init__(self, output_folder: str) -> None:
        self.output_folder = output_folder
        message = f'{output_folder} is not empty. Please remove or empty it and retry.'
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Rebuild from the original folder on pickle/copy.

        The inherited ``__reduce__`` would pass the rendered message back in
        as ``output_folder``, corrupting both the attribute and the message
        of the restored exception.
        """
        return (type(self), (self.output_folder,))