Spaces:
Runtime error
Runtime error
File size: 4,331 Bytes
3355824 3076727 3355824 4fd751f 3355824 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""
This code is a slight modification of perplexity by Hugging Face
https://huggingface.co/docs/transformers/perplexity
Both this code and the original code are published under the MIT license.
by Burhan Ul tayyab and Nicholas Chua
"""
import torch
import re
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from collections import OrderedDict
class GPT2PPL:
    """Compute GPT-2 perplexity for a text as a whole and line by line.

    Calling an instance returns the overall perplexity, the mean perplexity
    of the individual lines and the maximum line perplexity ("Burstiness").
    """

    def __init__(self, device="cpu", model_id="gpt2"):
        """Load the model and tokenizer named by ``model_id`` onto ``device``."""
        self.device = device
        self.model_id = model_id
        self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
        # Longest token window passed to the model in one forward call.
        self.max_length = self.model.config.n_positions
        # Distance (in tokens) between the starts of consecutive windows.
        self.stride = 512

    def getResults(self, threshold):
        """Return the score mapping ``{"HUMAN": threshold, "AI": 1 - threshold}``.

        The value is used verbatim; no clamping or scaling is applied.
        """
        return {"HUMAN": threshold, "AI": 1 - threshold}

    def __call__(self, sentence):
        """Score ``sentence`` and return ``(results, out)``.

        The text is split after sentence-ending punctuation and after
        newlines; the perplexity of the whole text and of every resulting
        line is computed and printed.

        Returns:
            A ``(results, out)`` tuple. When the text holds fewer than 100
            alphanumeric characters, ``results`` is ``{"status": message}``
            and ``out`` is the same message. Otherwise ``results`` is an
            ``OrderedDict`` with the keys ``"Perplexity"``,
            ``"Perplexity per line"`` (mean over lines), ``"Burstiness"``
            (max over lines) and ``"label"``.

        Raises:
            ValueError: if the whole text tokenizes to fewer than two tokens.
        """
        results = OrderedDict()

        # Only letters and digits count towards the minimum input length.
        valid_chars = sum(len(run) for run in re.findall("[a-zA-Z0-9]+", sentence))
        if valid_chars < 100:
            message = "Please input more text (min 100 characters)"
            return {"status": message}, message

        pieces = re.split(r'(?<=[.?!][ \[\(])|(?<=\n)\s*', sentence)
        lines = [piece for piece in pieces if piece]

        ppl = self.getPPL(sentence)
        print(f"Perplexity {ppl}")
        results["Perplexity"] = ppl

        # An opening bracket left dangling at the end of one line is carried
        # over and prepended to the next line that is scored.
        offset = ""
        perplexity_per_line = []
        for line in lines:
            if re.search("[a-zA-Z0-9]+", line) is None:
                continue
            if offset:
                line = offset + line
                offset = ""
            # Drop one leading newline/space if present.
            if line[0] == "\n" or line[0] == " ":
                line = line[1:]
            if line[-1] == "\n" or line[-1] == " ":
                line = line[:-1]
            elif line[-1] == "[" or line[-1] == "(":
                offset = line[-1]
                line = line[:-1]
            line_ppl = self._perplexity_or_none(line)
            if line_ppl is None:
                # A line of fewer than two tokens (e.g. a one-word heading)
                # has no next-token loss; scoring it would yield NaN and
                # abort the whole request, so it is left out of the stats.
                continue
            perplexity_per_line.append(line_ppl)

        if not perplexity_per_line:
            # No line was long enough to score: treat the text as one line.
            perplexity_per_line.append(ppl)

        mean_ppl = sum(perplexity_per_line) / len(perplexity_per_line)
        print(f"Perplexity per line {mean_ppl}")
        results["Perplexity per line"] = mean_ppl

        burstiness = max(perplexity_per_line)
        print(f"Burstiness {burstiness}")
        results["Burstiness"] = burstiness

        # NOTE(review): getResults returns a two-key dict, so this unpacking
        # binds its keys: out == "HUMAN" and label == "AI" for every input.
        # Kept as-is so the returned values keep their shape for existing
        # callers; confirm the intended verdict/label before changing it.
        out, label = self.getResults(results["Perplexity per line"])
        results["label"] = label
        return results, out

    def getPPL(self, sentence):
        """Return the perplexity of ``sentence`` truncated to an ``int``.

        Raises:
            ValueError: if ``sentence`` tokenizes to fewer than two tokens,
                in which case no perplexity can be computed.
        """
        ppl = self._perplexity_or_none(sentence)
        if ppl is None:
            raise ValueError(
                "cannot compute perplexity: text has fewer than two tokens"
            )
        return ppl

    def _perplexity_or_none(self, text):
        """Return the integer perplexity of ``text``, or ``None`` if it has
        fewer than two tokens.

        The tokens are scored in windows of at most ``self.max_length``
        tokens whose starts are ``self.stride`` apart; each window only
        contributes the tokens not already covered by the previous one.
        """
        encodings = self.tokenizer(text, return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        if seq_len < 2:
            return None

        nlls = []
        prev_end_loc = 0
        end_loc = 0
        for begin_loc in range(0, seq_len, self.stride):
            end_loc = min(begin_loc + self.max_length, seq_len)
            # Number of tokens in this window that were not scored before.
            trg_len = end_loc - prev_end_loc
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
            target_ids = input_ids.clone()
            # Everything before the last trg_len tokens is context only;
            # -100 is the label value the Hugging Face loss ignores.
            target_ids[:, :-trg_len] = -100
            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                neg_log_likelihood = outputs.loss * trg_len
            nlls.append(neg_log_likelihood)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        return int(torch.exp(torch.stack(nlls).sum() / end_loc))