# Spaces: Runtime error
import logging
from os.path import exists
from os.path import join as pjoin

import pandas as pd
from datasets import load_metric

import utils
from utils import dataset_utils as ds_utils
logs = utils.prepare_logging(__file__)

# Model id passed to the perplexity metric to score each text.
TOK_MODEL = "gpt2"
# Perplexity metric, loaded once at import time and shared by every DMTHelper.
# NOTE(review): `datasets.load_metric` is deprecated in favour of the
# `evaluate` library and was removed in `datasets` 3.0; on newer installs this
# line fails at import time. Confirm the pinned `datasets` version.
PERPLEXITY = load_metric("perplexity")
# Column name under which the per-text perplexity scores are stored.
PERPLEXITY_FIELD = "perplexity"
class DMTHelper:
    """Compute, cache and expose per-text perplexity scores for a dataset.

    Scores come from the module-level ``PERPLEXITY`` metric, run with the
    ``TOK_MODEL`` model over the ``ds_utils.OUR_TEXT_FIELD`` column of
    ``dstats.text_dset``. Results are kept as a DataFrame sorted from highest
    to lowest perplexity. They can be read from, or written to, a JSON cache
    file in ``dstats.dataset_cache_dir``.
    """

    def __init__(self, dstats, load_only=False):
        """Set up the helper without computing anything yet.

        Args:
            dstats: Dataset statistics/config object. This class reads its
                ``dataset_cache_dir``, ``use_cache``, ``save`` and
                ``text_dset`` attributes (``text_dset`` is presumably a
                Hugging Face ``Dataset`` -- confirm against the caller).
            load_only: If True, ``run_DMT_processing`` only loads cached
                results and never runs the perplexity computation.
        """
        self.dstats = dstats
        self.load_only = load_only
        # Raw results (column name -> values) from the last computation.
        self.results_dict = {}
        # Where in the Dataset object to find the text for the calculation.
        self.text_field = ds_utils.OUR_TEXT_FIELD
        # Results in DataFrame form; None until loaded or computed.
        self.df = None
        # Cache file for the results DataFrame.
        self.perplexities_df_fid = pjoin(self.dstats.dataset_cache_dir,
                                         "perplexities_df.json")

    def run_DMT_processing(self):
        """Populate ``self.df`` from the cache, or compute it if allowed.

        The cache file is read when ``dstats.use_cache`` is set and the file
        exists. Otherwise, unless ``load_only`` is set, perplexities are
        computed. If ``dstats.save`` is also set, they are then written to
        the cache file. In load-only mode with no usable cache, ``self.df``
        is left unchanged.
        """
        if self.dstats.use_cache and exists(self.perplexities_df_fid):
            self.df = ds_utils.read_df(self.perplexities_df_fid)
        elif not self.load_only:
            self.prepare_text_perplexities()
            # Save only freshly computed results: a cache hit is already on
            # disk, and load-only mode has nothing new to write.
            if self.dstats.save:
                ds_utils.write_df(self.df, self.perplexities_df_fid)

    def prepare_text_perplexities(self):
        """Score every text with the perplexity metric and build ``self.df``.

        Sets ``self.results_dict`` to
        ``{PERPLEXITY_FIELD: scores, self.text_field: texts}`` and sets
        ``self.df`` to the matching DataFrame, sorted by perplexity in
        descending order (highest first).
        """
        # Extract the text column once. It serves both as the metric input
        # and as the text column of the result, so the dataset is not
        # indexed a second time for the same column.
        texts = self.dstats.text_dset[self.text_field]
        if len(texts) == 0:
            # Nothing to score: skip the metric call (which would otherwise
            # run TOK_MODEL) but keep the expected column layout.
            self.results_dict = {PERPLEXITY_FIELD: [], self.text_field: []}
        else:
            eval_results = PERPLEXITY.compute(input_texts=texts,
                                              model_id=TOK_MODEL)
            # TODO: What other stuff might be useful to grab?
            self.results_dict = {
                PERPLEXITY_FIELD: eval_results["perplexities"],
                self.text_field: texts,
            }
        self.df = pd.DataFrame(self.results_dict).sort_values(
            by=PERPLEXITY_FIELD, ascending=False)

    def get_df(self):
        """Return the perplexity DataFrame (None if not loaded or computed)."""
        return self.df