"""Build preference-pair / SFT datasets from story CSVs and a pairwise-comparison CSV."""
import glob

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

# Cluster-specific default locations for the story CSVs and the pair CSV.
datapath = '/cluster/work/lawecon/Work/penghao/dataset/stories/'
pairpath = '../../../work/lawecon/Work/penghao/pairs.csv'


class StoryPairDataset:
    """Builds preference pairs (or SFT texts) from story CSVs and a pair CSV.

    This is a dataset builder rather than a `datasets.Dataset` subclass; the
    finished `DatasetDict` is exposed as `self.dataset`.
    """

    def __init__(self, datapath, pairpath, tokenizer, task, used_dataset_size=-1, train_test_split=0.1,
                 split_by='random', max_len=4096 * 2, mode='m3', max_time_window=3000,
                 least_likes=5, margin=True):
        self.datapath = datapath
        print(self.datapath)
        self.train_test_split = train_test_split
        self.pairpath = pairpath
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.split_by = split_by
        self.least_likes = least_likes
        self.used_dataset_size = used_dataset_size
        # Mode 'm2' widens the allowed time gap between paired stories; every
        # other mode uses the caller-supplied window.
        if mode == 'm2':
            self.max_time_window = 12009600
        else:
            self.max_time_window = max_time_window
        self.pair = self.load_pair()

        self.task = task
        self.margin = margin
        self.stories = self.load_stories(self.datapath)
        print(self.stories.columns)
        print(len(self.stories))

        self.train, self.test = self.train_test_split__()
        self.train = self.marginInclude(self.train)
        self.test = self.marginInclude(self.test)

        self.dataset = self.make_dataset()
        print('current setting mode is ', mode)
        print('current setting split_by is ', split_by)
        print('current setting least_likes is ', least_likes)

    def load_stories(self, path):
        """Read every CSV under `path` into one DataFrame of stories."""
        frames = []
        for file in glob.glob(path + '*.csv'):
            try:
                df = pd.read_csv(file)
                if df.empty:
                    print(f"Warning: {file} is empty or not readable.")
                    continue
                frames.append(df)
            except pd.errors.EmptyDataError:
                # Zero-byte file: skip silently.
                pass
            except pd.errors.ParserError:
                print(f"Error: {file} cannot be parsed.")
            except Exception as e:
                print(f"Error: An unexpected error occurred while processing {file}. Details: {str(e)}")

        # Concatenate once at the end instead of growing a DataFrame in the loop.
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
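
    # The story CSVs are expected to provide at least these columns (inferred
    # from the lookups in this class, not a validated schema): story_id,
    # prompt_id, prompt, story_title, story_text, posted_date, genre.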

    def load_pair(self):
        """Load pairwise comparisons, normalised so story1 is always the preferred one."""
        pair = pd.read_csv(self.pairpath)

        # Keep only comparisons within the allowed time window and with enough votes.
        pair = pair[pair['time_lag'] <= self.max_time_window]
        print('the max of time lag is ', pair['time_lag'].max())
        pair = pair[pair['least_likes'] >= self.least_likes]

        # Where the preference points the other way (rel < 0), swap the two
        # stories so story1_id is always preferred, and keep |rel| as strength.
        pair.loc[pair['rel'] < 0, ['story1_id', 'story2_id']] = pair.loc[
            pair['rel'] < 0, ['story2_id', 'story1_id']].values
        pair['rel'] = pair['rel'].abs()

        pair = pair[pair['story1_id'] != pair['story2_id']]
        if self.used_dataset_size == -1:
            self.used_dataset_size = len(pair)
        else:
            pair = pair.sample(n=self.used_dataset_size)
        print('the total number of pairs is ', len(pair))

        pair = pair.drop_duplicates(subset=['story1_id', 'story2_id'])

        # Ties (rel == 0) carry no preference signal.
        pair = pair[pair['rel'] != 0]
        print('the number of effective pairs is ', len(pair))
        return pair
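
    # pairs.csv is likewise expected to contain at least: prompt_id, story1_id,
    # story2_id, rel (signed preference strength), time_lag (posting-time gap,
    # in the source data's units) and least_likes.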

    def marginInclude(self, df):
        # `make_dataset` expects a 'margin' column when self.margin is set, so
        # keep the preference strength under that name; otherwise drop it.
        if self.margin:
            df = df.rename(columns={'rel': 'margin'})
        else:
            df = df.drop(columns=['rel'])
        return df

    def train_test_split__(self):
        '''
        Split the pairs into train and test sets according to self.split_by:
        'time' holds out pairs whose first story was posted after 2022,
        'random' samples uniformly, 'genre' holds out whole genres, and
        'chaos' replaces both stories with random ones from the same prompt
        (a shuffled baseline) before splitting randomly.
        :return: (train, test) DataFrames
        '''
        if self.split_by == 'time':
            self.stories['posted_date'] = pd.to_datetime(self.stories['posted_date'])
            self.stories['posted_date'] = self.stories['posted_date'].dt.strftime('%Y%m%d')

            # Hold out every pair whose first story was posted after 2022.
            posted_after_2022 = self.pair['story1_id'].apply(
                lambda x: int(self.stories[self.stories['story_id'] == x]['posted_date'].values[0]) > 20220000)
            test = self.pair[posted_after_2022]
            train = self.pair[~posted_after_2022]
            print('the number of test set is ', len(test))
            print('the number of train set is ', len(train))
            print('the ratio of test set is ', len(test) / (len(test) + len(train)))

        elif self.split_by == 'random':
            train, test = train_test_split(self.pair, test_size=self.train_test_split)

        elif self.split_by == 'genre':
            # Assign each pair the genre of its first story, then move whole
            # genres (largest first) into the test set until the target ratio is hit.
            self.pair['genre'] = self.pair['story1_id'].apply(
                lambda x: self.stories[self.stories['story_id'] == x]['genre'].values[0])
            genre = {}
            for c in self.pair['genre'].unique():
                genre[c] = len(self.pair[self.pair['genre'] == c])

            genre = dict(sorted(genre.items(), key=lambda item: item[1], reverse=True))
            print(genre)
            total = sum(genre.values())

            test_genre = []
            test_count = 0
            while genre and test_count < total * self.train_test_split:
                largest = next(iter(genre))
                test_genre.append(largest)
                test_count += genre.pop(largest)
                # Stop early if adding the next genre would overshoot the target.
                if genre and test_count + genre[next(iter(genre))] > total * self.train_test_split:
                    break

            test = self.pair[self.pair['genre'].apply(lambda x: x in test_genre)]
            train = self.pair[self.pair['genre'].apply(lambda x: x not in test_genre)]
            print('the genre of test set is ', test_genre)
            print('the percentage of test set is ', test_count / total, ' where total is ', total)

        elif self.split_by == 'chaos':
            # Shuffled control: replace both stories of every pair with random
            # stories answering the same prompt, destroying the preference signal.
            self.pair = self.pair.reset_index(drop=True)
            for i in range(len(self.pair)):
                same_prompt = self.stories[
                    self.stories['prompt_id'] == self.pair.at[i, 'prompt_id']]['story_id'].values
                self.pair.at[i, 'story1_id'] = np.random.choice(same_prompt)
                self.pair.at[i, 'story2_id'] = np.random.choice(same_prompt)
            train, test = train_test_split(self.pair, test_size=self.train_test_split)
        else:
            raise ValueError(f"Unknown split_by: {self.split_by!r}")
        return train, test

    def apply_template_to_text(self, row):
        """Render one pair into the chat-template strings needed by each task."""
        prompt_id, story1_id, story2_id = row[['prompt_id', 'story1_id', 'story2_id']]

        # Both stories answer the same prompt.
        prompt = self.stories[self.stories['prompt_id'] == prompt_id]['prompt'].values[0]
        chosen_story = self.stories[self.stories['story_id'] == story1_id]['story_title'].values[0] + '\n' + \
                       self.stories[self.stories['story_id'] == story1_id]['story_text'].values[0]
        rejected_story = self.stories[self.stories['story_id'] == story2_id]['story_title'].values[0] + '\n' + \
                         self.stories[self.stories['story_id'] == story2_id]['story_text'].values[0]

        chosen_text = [{'role': 'user', 'content': prompt},
                       {'role': 'assistant', 'content': chosen_story}]
        rejected_text = [{'role': 'user', 'content': prompt},
                         {'role': 'assistant', 'content': rejected_story}]

        chosen_text = self.tokenizer.apply_chat_template(chosen_text, tokenize=False)
        rejected_text = self.tokenizer.apply_chat_template(rejected_text, tokenize=False)

        # Note: some chat templates already emit bos_token, in which case adding
        # it again here duplicates the special token.
        res = {}
        res['chosen_text'] = self.tokenizer.bos_token + chosen_text + self.tokenizer.eos_token
        res['rejected_text'] = self.tokenizer.bos_token + rejected_text + self.tokenizer.eos_token

        # For SFT, the chosen conversation is the training text.
        res['text'] = self.tokenizer.bos_token + chosen_text + self.tokenizer.eos_token

        # Split the rendered conversation into prompt and completion at the marker
        # that precedes the assistant turn in each model family's chat template.
        if 'gemma' in self.tokenizer.name_or_path:
            split_words = '<start_of_turn>model\n'
        elif 'mistral' in self.tokenizer.name_or_path or 'llama' in self.tokenizer.name_or_path:
            split_words = '[/INST]'
        else:
            raise ValueError(f"No assistant-turn marker known for tokenizer {self.tokenizer.name_or_path}")

        prompt_text, chosen_completion = chosen_text.rsplit(split_words, 1)
        prompt_text += split_words
        rejected_completion = rejected_text.rsplit(split_words, 1)[-1]

        res['prompt'] = self.tokenizer.bos_token + prompt_text
        res['chosen'] = chosen_completion + self.tokenizer.eos_token
        res['rejected'] = rejected_completion + self.tokenizer.eos_token
        return res
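
    # For a Mistral/Llama-2-style template, for instance, `prompt` ends with
    # '[/INST]' and `chosen`/`rejected` are the assistant completions after it.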

    def convert_sft(self, df):
        """Collapse pairs into one row per unique story for plain SFT."""
        story_ids = list(set(df['story1_id'].values) | set(df['story2_id'].values))

        df = pd.DataFrame()
        df['story1_id'] = story_ids
        # apply_template_to_text expects both columns; point them at the same story.
        df['story2_id'] = df['story1_id']

        def get_prompt_id(x):
            return self.stories[self.stories['story_id'] == x]['prompt_id'].values[0]

        df['prompt_id'] = df['story1_id'].apply(get_prompt_id)
        return df

    def make_dataset(self):
        self.train.reset_index(drop=True, inplace=True)
        self.test.reset_index(drop=True, inplace=True)
        # Each task consumes a different subset of the rendered fields.
        if self.task == 'rm':
            entries = ['chosen_text', 'rejected_text']
        elif self.task == 'dpo':
            entries = ['prompt', 'chosen', 'rejected']
        elif self.task == 'sft':
            self.train = self.convert_sft(self.train)
            self.test = self.convert_sft(self.test)
            entries = ['text']
        else:
            raise ValueError(f"Unknown task: {self.task!r}")

        print('the columns of train is ', self.train.columns)
        for index, row in self.train.iterrows():
            res = self.apply_template_to_text(row)
            for e in entries:
                self.train.at[index, e] = res[e]

        for index, row in self.test.iterrows():
            res = self.apply_template_to_text(row)
            for e in entries:
                self.test.at[index, e] = res[e]

        print('the first example of train is ', self.train.iloc[0])

        # SFT frames are rebuilt by convert_sft and carry no margin column.
        if self.margin and self.task != 'sft':
            entries.append('margin')

        train_dataset = Dataset.from_pandas(self.train[entries])
        test_dataset = Dataset.from_pandas(self.test[entries])

        return DatasetDict({'train': train_dataset, 'test': test_dataset})

    def save_dataset(self, path):
        '''
        Save the dataset to the readsy folder, one level above the working directory.
        :param path: destination path, relative to the parent directory
        :return:
        '''
        self.dataset.save_to_disk('../' + path)
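

# A minimal usage sketch. It assumes a Hugging Face `transformers` tokenizer
# whose chat template this class recognises; the model name, dataset size and
# output path below are example values, not fixed by this module.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
    builder = StoryPairDataset(datapath, pairpath, tokenizer, task='dpo',
                               used_dataset_size=1000, split_by='random')
    print(builder.dataset)
    builder.save_dataset('readsy/dpo_dataset')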