R-Detect / data_loader.py
songyiliao's picture
feat: initial commit (#1)
1244519 verified
import random
import tqdm
import datasets
import re
import transformers
import numpy as np
from utils import MGT, HWT, config
# Module-level T5 tokenizer used by filter_data() to measure each text's length
# in tokens. model_max_length=512 matches filter_data's default upper bound.
# It is created at import time; presumably this fetches the tokenizer files
# from the Hugging Face Hub when they are not already cached.
preproc_tokenizer = transformers.AutoTokenizer.from_pretrained(
    "google-t5/t5-small", model_max_length=512
)
def process_spaces(text):
    """Undo tokenizer-style spacing and normalise line breaks in ``text``.

    Removes the space that tokenized corpora put before punctuation and
    contractions (for example ``"word ,"`` becomes ``"word,"`` and
    ``"do n't"`` becomes ``"don't"``). It also:

    - converts ``<newline>`` markers to newlines;
    - converts PTB-style double-quote marks to a plain double quote;
    - expands a two-dot run to an ellipsis;
    - capitalises the pronoun ``i`` when it follows a space.

    A final pass converts Windows line endings and deletes literal
    backslash-``n`` sequences. It also removes any ``!`` that is directly
    followed by a newline, together with that newline. Runs of newlines are
    then collapsed into a single newline.

    Args:
        text: The raw string to clean.

    Returns:
        The cleaned string.
    """
    text = (
        text.replace(" ,", ",")
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(" '", "'")
        .replace(" ’ ", "'")
        .replace(" :", ":")
        .replace("<newline>", "\n")
        .replace("`` ", '"')
        .replace(" ''", '"')
        .replace("''", '"')
    )
    # Expand a two-dot run ("so.. yes") into an ellipsis. A plain
    # replace(".. ", "... ") also matches the last two dots of an existing
    # ellipsis and would turn "... " into ".... ". That includes the "... "
    # produced above from tokenized ". . .". So only expand dots that are
    # not preceded by another dot.
    text = re.sub(r"(?<!\.)\.\. ", "... ", text)
    text = (
        text.replace(" )", ")")
        .replace("( ", "(")
        .replace(" n't", "n't")
        .replace(" i ", " I ")
        .replace(" i'", " I'")
        .replace("\\'", "'")
        .replace("\n ", "\n")
    )
    # NOTE: "\\n" is a literal backslash + "n" and is deleted, not turned
    # into a newline; "!\n" drops both the "!" and the line break.
    text = text.replace("\r\n", "\n").replace("\\n", "").replace("!\n", "")
    return re.sub("\n+", "\n", text)
def trim_to_shorter_length(texta, textb):
    """Cut two texts down to the same number of space-separated words.

    Both strings are split on single spaces (not on arbitrary whitespace), so
    repeated spaces and embedded newlines are kept inside the pieces. The
    longer text is truncated to the piece count of the shorter one. The
    shorter text comes back unchanged.

    Args:
        texta: First text.
        textb: Second text.

    Returns:
        A ``(texta, textb)`` tuple holding the truncated strings.
    """
    words_a = texta.split(" ")
    words_b = textb.split(" ")
    keep = min(len(words_a), len(words_b))
    return " ".join(words_a[:keep]), " ".join(words_b[:keep])
def load_HC3():
    """Load the HC3 dataset and flatten it into parallel text/label lists.

    The dataset is read from ``config["local_dataset"]`` when that entry is
    truthy, otherwise from the ``Hello-SimpleAI/HC3`` Hub repository. Both
    sources use the ``"all"`` config, and only the ``"train"`` split is used.
    A question is kept only if:

    - it has at least one human answer and one ChatGPT answer, and
    - the first answer of each kind has more than five whitespace-separated
      words.

    Returns:
        A dict ``{"text": [...], "label": [...]}`` with two entries per kept
        question. The first human answer is labelled ``HWT`` and the first
        ChatGPT answer is labelled ``MGT``.
    """
    local_path = config["local_dataset"]
    if local_path:
        print("Loading local HC3 dataset", local_path)
        # trust_remote_code allows `datasets` to run a loading script shipped
        # with the local copy -- only point this at data you trust.
        ds = datasets.load_dataset(local_path, name="all", trust_remote_code=True)
    else:
        print("Loading remote HC3 dataset")
        ds = datasets.load_dataset("Hello-SimpleAI/HC3", name="all")
    ds = ds["train"]  # DatasetDict -> Dataset

    filtered_ds = [
        item
        for item in ds
        if (
            len(item["human_answers"]) > 0
            and len(item["chatgpt_answers"]) > 0
            and len(item["human_answers"][0].split()) > 5
            and len(item["chatgpt_answers"][0].split()) > 5
        )
    ]

    data_new = {"text": [], "label": []}
    # NOTE(review): the body of this loop was missing from the reviewed copy.
    # It is reconstructed from two things:
    #   - the filter above only validates the first answer of each kind;
    #   - filter_data() expects parallel "text"/"label" lists labelled
    #     HWT/MGT.
    # Confirm against the repository, in particular whether process_spaces()
    # was applied here.
    for item in tqdm.tqdm(filtered_ds, desc="Parsing data"):
        data_new["text"].append(process_spaces(item["human_answers"][0]))
        data_new["label"].append(HWT)
        data_new["text"].append(process_spaces(item["chatgpt_answers"][0]))
        data_new["label"].append(MGT)
    return data_new
def _keep_by_token_length(texts, low, high):
    """Return the texts whose ``preproc_tokenizer`` token count is in [low, high]."""
    if not texts:
        # Nothing to measure; avoid calling the tokenizer with an empty batch.
        return []
    input_ids = preproc_tokenizer(texts)["input_ids"]
    return [text for text, ids in zip(texts, input_ids) if low <= len(ids) <= high]


def filter_data(data_o, long_train_threshold_low=150, long_train_threshold_high=512):
    """Split texts by label, keep mid-length ones, and pair them by word count.

    Args:
        data_o: Dict with parallel ``"text"`` and ``"label"`` lists, e.g. the
            result of ``load_HC3``. Labels are compared against ``HWT`` and
            ``MGT``; entries with any other label are ignored.
        long_train_threshold_low: Minimum token count (inclusive), measured
            with ``preproc_tokenizer``, for a text to be kept.
        long_train_threshold_high: Maximum token count (inclusive).

    Returns:
        Dict ``{HWT: [...], MGT: [...]}`` of equal-length lists.

        - The i-th kept human text is paired with the i-th kept machine text.
          Each side is filtered independently, so a pair need not come from
          the same question.
        - Both texts in a pair are truncated to the shorter one's word count.
        - Surplus texts on the longer side are dropped.
    """
    data_HWT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == HWT
    ]
    data_MGT = [
        text for text, label in zip(data_o["text"], data_o["label"]) if label == MGT
    ]

    # Keep only texts whose token count lies within the configured bounds.
    # This also tends to drop very short, low-quality/garbage answers.
    long_HWT = _keep_by_token_length(
        data_HWT, long_train_threshold_low, long_train_threshold_high
    )
    long_MGT = _keep_by_token_length(
        data_MGT, long_train_threshold_low, long_train_threshold_high
    )

    # Print stats about the remaining human-written texts.
    print(f"Total number of samples: {len(long_HWT)}")
    # Guard the empty case: np.mean([]) emits "Mean of empty slice" and
    # returns nan, so print nan directly without the warning.
    avg_words = (
        np.mean([len(x.split()) for x in long_HWT]) if long_HWT else float("nan")
    )
    print(f"Average number of words: {avg_words}")

    data = {
        HWT: [],
        MGT: [],
    }
    # NOTE(review): the two append lines were missing from the reviewed copy.
    # They are reconstructed from the HWT/MGT result dict and the
    # "add to the data" comment that preceded them -- confirm against the
    # repository.
    for o, s in zip(long_HWT, long_MGT):
        o, s = trim_to_shorter_length(o, s)
        data[HWT].append(o)
        data[MGT].append(s)

    return data
# Test code
# data_o = load_HC3()
# data = filter_data(data_o)
# real = data[HWT] # [:args.train_real_num] len== n_samples, many sentences of words
# generated = data[MGT]
# print(real[:5])
# print(generated[:5])