import os
import json

from torch.utils.data import Dataset
from datasets import Dataset as AdvancedDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Default filenames for the OpenPrompt data splits and the dict-format dataset.
DEFAULT_TRAIN_DATA_NAME = "train_openprompt.json"
DEFAULT_TEST_DATA_NAME = "test_openprompt.json"
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"


def get_open_prompt_data(path_for_data):
    """Load the train and test splits from JSON files under `path_for_data`."""
    with open(os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME)) as f:
        train_data = json.load(f)
    with open(os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME)) as f:
        test_data = json.load(f)
    return train_data, test_data


def get_tok_and_model(path_for_model):
    """Load a cached tokenizer and causal LM from a local directory."""
    if not os.path.exists(path_for_model):
        raise RuntimeError(f"No cached model found at {path_for_model}.")
    # Pad on the left so generated tokens directly follow the prompt.
    tok = AutoTokenizer.from_pretrained(path_for_model, padding_side='left')
    # GPT-2-style models ship without a pad token; reuse the EOS token id
    # (50256), the usual default for open-ended generation.
    tok.pad_token_id = 50256
    model = AutoModelForCausalLM.from_pretrained(path_for_model)
    return tok, model


class OpenPromptDataset(Dataset):
    """Thin torch Dataset wrapper around an in-memory list of prompt records."""

    def __init__(self, data) -> None:
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

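
# OpenPromptDataset plugs into a standard torch DataLoader; a minimal sketch
# (the batch size here is an arbitrary illustrative choice):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(OpenPromptDataset(train_data), batch_size=8, shuffle=True)
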
def get_dataset(train_data, test_data):
    """Wrap the raw splits in OpenPromptDataset objects."""
    train_dataset = OpenPromptDataset(train_data)
    test_dataset = OpenPromptDataset(test_data)
    return train_dataset, test_dataset


def get_dict_dataset(path_for_data):
    """Load the column-oriented (dict of lists) dataset from JSON."""
    with open(os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME)) as f:
        dict_data = json.load(f)
    return dict_data


def get_advance_dataset(dict_data):
    """Build a Hugging Face `datasets.Dataset` from a dict of columns."""
    if not isinstance(dict_data, dict):
        raise RuntimeError("dict_data is not a dict.")
    dataset = AdvancedDataset.from_dict(dict_data)
    return dataset
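

# Usage sketch: how these helpers are typically wired together. The paths
# "./data" and "./model" below are illustrative placeholders, not values
# taken from this module.
if __name__ == "__main__":
    # Load the JSON splits and wrap them in torch Datasets.
    train_data, test_data = get_open_prompt_data("./data")
    train_dataset, test_dataset = get_dataset(train_data, test_data)
    print(f"train: {len(train_dataset)} examples, test: {len(test_dataset)} examples")

    # Load the cached tokenizer/model and build a Hugging Face dataset
    # from the column-oriented JSON file.
    tok, model = get_tok_and_model("./model")
    dict_data = get_dict_dataset("./data")
    hf_dataset = get_advance_dataset(dict_data)
    print(hf_dataset)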