File size: 9,065 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import logging
from typing import Dict, List

import numpy as np

from llm_studio.src.datasets.text_utils import get_texts
from llm_studio.src.utils.utils import PatchedAttribute

logger = logging.getLogger(__name__)


class ConversationChainHandler:
    """
    Partitions a DataFrame into chains of conversation rounds.

    Each chain is a list of rounds, and each round is the triplet
    (system, prompt, answer) taken from a single DataFrame row.

    The shape of the chains depends on the DataFrame and the configuration:

    - Without a parent id column in the DataFrame, every chain is a single
      round. For every `i`-th row, 0 <= `i` < len(df), the chain is
          [(system_i, prompt_i, answer_i)]

    - With a parent id column and `cfg.dataset.limit_chained_samples` set to
      False, there is one chain per row: the chain for the `i`-th row holds all
      rounds preceding it in its conversation, followed by the row itself:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_i, prompt_i, answer_i)]

    - With a parent id column and `cfg.dataset.limit_chained_samples` set to
      True, only complete conversations are produced, i.e. one chain per row
      that is not the parent of any other row:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_end_conversation_i,
            prompt_end_conversation_i,
            answer_end_conversation_i)]
    """

    def __init__(
        self,
        df,
        cfg,
    ):
        # Do not set self.cfg = cfg, as ConversationChainHandler
        # will be used with PatchedAttribute context manager.
        self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
        self.prompts = get_texts(df, cfg, separator="")
        # get_answers / get_systems size their fallbacks on self.prompts,
        # so prompts must be computed first.
        self.answers = self.get_answers(df, cfg)
        self.systems = self.get_systems(cfg, df)

    def get_conversation_chain_ids(self, cfg, df):
        """
        Gets the conversation chains of the given DataFrame as positional indices.

        E.g. if conversation_chain_ids = [[13, 44, 8], ...],
        then the first conversation chain consists of
        [df.iloc[13], df.iloc[44], df.iloc[8]]
        with
            - df.iloc[13] denoting the first round of the conversation
              (a row without a valid parent id),
            - df.iloc[44] denoting the second round of the conversation,
            - df.iloc[8] denoting the last round of the chain. If
              limit_chained_samples is True, df.iloc[8] is additionally not
              the parent of any other row, i.e. the end of the conversation.

        Args:
            cfg: config object; reads cfg.dataset.parent_id_column and
                cfg.dataset.limit_chained_samples.
            df: DataFrame holding the samples. An "id" column is required
                when a parent id column is configured and present.

        Returns:
            A list of chains, each a list of positional row indices into df.
        """
        parent_id_column = cfg.dataset.parent_id_column
        if (
            parent_id_column in ["None", None]
            # Handle case where train Dataframe has conversation chains,
            # but val Dataframe does not
            or parent_id_column not in df.columns
        ):
            # no parent id column, so each triplet (system_i, prompt_i, answer_i)
            # is a conversation chain
            return [[idx] for idx in range(len(df))]

        assert "id" in df.columns, (
            f"id column required for conversation chaining, "
            f"DataFrame only has {df.columns}."
        )
        # sample and parent ids can have any dtype, such as str, int, float, etc.
        # id column can be int, while parent_id column can be float
        # (as some values are NaN) so we cast id to the same dtype
        sample_ids = df["id"].astype(df[parent_id_column].dtype).tolist()
        parent_ids = df[parent_id_column].tolist()
        # Some datasets may include parent ids that are not in the dataset;
        # those are treated as "no parent".
        sample_ids_set = set(sample_ids)
        parent_ids = [
            parent_id if parent_id in sample_ids_set else "None"
            for parent_id in parent_ids
        ]

        # Only keep real parent links: drop None/"None" and non-finite floats
        # (NaN marks a missing parent in float-typed columns).
        id2parent_id = {
            sample_id: parent_id
            for sample_id, parent_id in zip(sample_ids, parent_ids)
            if parent_id not in [None, "None"]
            and (
                not isinstance(parent_id, float)
                or (not np.isnan(parent_id) and not np.isinf(parent_id))
            )
        }
        if cfg.dataset.limit_chained_samples:
            # end id == id is not a parent id of another conversation id
            valid_parent_ids = set(id2parent_id.values())
            conversation_end_ids = [
                sample_id
                for sample_id in sample_ids
                if sample_id not in valid_parent_ids
            ]
        else:
            conversation_end_ids = sample_ids
        conversation_chain_ids = [
            self.get_conversation_ids(id2parent_id, conversation_end_id)
            for conversation_end_id in conversation_end_ids
        ]
        # Map from df["id"] to positional index. If ids are duplicated,
        # the last row carrying that id wins.
        dataframeid2idx = {
            sample_id: position for position, sample_id in enumerate(sample_ids)
        }
        return [
            [dataframeid2idx[conversation_id] for conversation_id in conversation_ids]
            for conversation_ids in conversation_chain_ids
        ]

    def get_answers(self, df, cfg):
        """
        Returns the answer of every row as a string.

        Falls back to an empty string per prompt when the configured
        answer column is not in df (requires self.prompts to be set).
        """
        answer_column = cfg.dataset.answer_column
        if answer_column in df.columns:
            return df[answer_column].astype(str).tolist()
        return ["" for _ in range(len(self.prompts))]

    def get_systems(self, cfg, df):
        """
        Returns the system prompt of every row as a string.

        Falls back to an empty string per prompt (requires self.prompts to be
        set) when no system column is configured ("None" or None), or when the
        configured column is missing from df; the latter also logs a warning.
        """
        system_column = cfg.dataset.system_column
        if system_column in ["None", None]:
            # Treat None like "None", consistent with the parent id column.
            return ["" for _ in range(len(self.prompts))]
        if system_column not in df.columns:
            logger.warning(
                f"System column {system_column} not found. "
                f"Disabling functionality."
            )
            return ["" for _ in range(len(self.prompts))]
        return df[system_column].astype(str).tolist()

    @staticmethod
    def get_conversation_ids(id2parent_id, end_id):
        """
        Gets the conversation chain that ends at a given conversation ID.

        Args:
            id2parent_id: A dictionary containing the mapping of IDs
                to its previous parent ID.
            end_id: The ID of the end of the conversation in the chain.

        Returns:
            A list of conversation IDs representing the conversation chain.
            The chain is ordered from the first conversation id to end_id.

        Raises:
            ValueError: if following parents takes more than 1000 steps,
                which indicates a circular parent chain (dataframe issue).
        """
        # Walk from end_id up to the root, collecting ids in reverse order;
        # reversing once at the end avoids quadratic list prepends.
        reversed_chain = [end_id]
        current_id = end_id
        loop_counter = 0
        while current_id in id2parent_id:
            loop_counter += 1
            current_id = id2parent_id[current_id]
            reversed_chain.append(current_id)
            if loop_counter > 1000:
                raise ValueError(
                    f"Parent chain of sample with idx {end_id} "
                    f"exceeds max loop count of 1000. "
                    f"Please ensure that parent chain is not circular."
                )
        reversed_chain.reverse()
        return reversed_chain

    def __len__(self):
        return len(self.conversation_chain_ids)

    def __getitem__(self, idx):
        """
        Gets a single conversation chain as a dict with the keys
        "prompts", "answers" and "systems", each a list with one entry
        per round.

        The conversation may be:
        - a single (system, prompt, answer) round,
          if cfg.dataset.parent_id_column == "None" or
          there is no parent_id for the conversation
        - a conversation potentially ending somewhere in
          the middle of the conversation, if the conversation
          is chained and limit_chained_samples is set to False
        - always a complete conversation, if the conversation is chained
          and limit_chained_samples is True
        """
        chain = self.conversation_chain_ids[idx]
        return {
            "prompts": [self.prompts[i] for i in chain],
            "answers": [self.answers[i] for i in chain],
            "systems": [self.systems[i] for i in chain],
        }

    def __iter__(self):
        """Iterates over all conversation chains in order."""
        for idx in range(len(self)):
            yield self[idx]

    def get_conversation_end_ids(self):
        """
        Gets the positional index of the last round of each conversation chain.
        """
        return [
            conversation_chain[-1] for conversation_chain in self.conversation_chain_ids
        ]


def get_conversation_chains(
    df, cfg, limit_chained_samples=True
) -> List[Dict[str, List[str]]]:
    """
    Builds all conversation chains of df as plain dicts.

    Args:
        df: DataFrame holding the samples.
        cfg: config object; cfg.dataset.limit_chained_samples is temporarily
            overridden with `limit_chained_samples` while the chains are built.
        limit_chained_samples: whether to keep only complete conversations.

    Returns:
        One dict per chain with the keys "prompts", "answers" and "systems".
    """
    with PatchedAttribute(cfg.dataset, "limit_chained_samples", limit_chained_samples):
        conversation_chain_handler = ConversationChainHandler(df, cfg)
    # Index explicitly instead of relying on the legacy __getitem__ iteration
    # protocol, which needed a type: ignore and breaks silently if __getitem__
    # ever stops raising IndexError past the end.
    return [
        conversation_chain_handler[idx]
        for idx in range(len(conversation_chain_handler))
    ]