Ezi's picture
Upload 312 files
46df0b6
import logging
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
import seaborn as sns
import statistics
from os.path import join as pjoin
import pandas as pd
import utils
from utils import dataset_utils as ds_utils
from collections import Counter
from os.path import exists, isdir
from os.path import join as pjoin
# Column holding the raw text of each instance.
TEXT_FIELD = "text"
# Column holding the tokenized text; len() of each entry is the instance length.
TOKENIZED_FIELD = "tokenized_text"
# Column added to the lengths table holding each instance's token count.
LENGTH_FIELD = "length"
# Keys of the length-measurements dict (also the keys of the cached JSON).
UNIQ = "num_instance_lengths"
AVG = "average_instance_length"
STD = "standard_dev_instance_length"
# Module logger, configured by the project's utils.prepare_logging helper.
logs = utils.prepare_logging(__file__)
def make_fig_lengths(lengths_df):
    """Build the text-length distribution figure.

    Draws a histogram of ``lengths_df[LENGTH_FIELD]`` with a kernel density
    estimate overlaid, plus a rug with one tick per instance.

    :param lengths_df: DataFrame with a LENGTH_FIELD column of token counts.
    :return: the matplotlib Figure holding the plot.
    """
    logs.info("Creating lengths figure.")
    # Non-interactive backend: the figure is only saved / handed to the UI.
    plt.switch_backend('Agg')
    fig_tok_lengths, axs = plt.subplots(figsize=(15, 6), dpi=150)
    # seaborn draws onto the Axes passed via ``ax``, so both layers land on
    # the figure created above.
    sns.histplot(data=lengths_df, kde=True, ax=axs, x=LENGTH_FIELD, legend=False)
    # Name the column explicitly: without ``x`` seaborn treats the frame as
    # wide-form data (one rug per numeric column, with a hue legend).
    sns.rugplot(data=lengths_df, x=LENGTH_FIELD, ax=axs)
    # Set labels on this Axes directly (after plotting, so they take
    # precedence over seaborn's default axis labels).
    axs.set_xlabel("Number of tokens")
    axs.set_title("Binned counts of text lengths, with kernel density estimate and ticks for each instance.")
    return fig_tok_lengths
class DMTHelper:
    """Compute, cache, and expose the text-length results shown in the DMT.

    The results are a figure of the length distribution, a table of texts and
    their lengths, and a dict of summary measurements. They are loaded from
    ``<dataset_cache_dir>/lengths`` when available and otherwise computed
    with the ``Lengths`` class.
    """

    def __init__(self, dstats, load_only=False, save=True):
        """
        :param dstats: object supplying ``tokenized_df``, ``use_cache`` and
            ``dataset_cache_dir``.
        :param load_only: only load results from the cache; never compute them.
        :param save: write freshly computed results to the cache.
        """
        self.tokenized_df = dstats.tokenized_df
        # Whether to only use cache
        self.load_only = load_only
        # Whether to try using cache first.
        # Must be true when load_only is set (the cache is then the only
        # source of results), so force it on in that case.
        self.use_cache = dstats.use_cache or load_only
        self.cache_dir = dstats.dataset_cache_dir
        self.save = save
        # Lengths class object
        self.lengths_obj = None
        # Content shared in the DMT:
        # The figure, the table, and the sufficient statistics (measurements)
        self.fig_lengths = None
        self.lengths_df = None
        self.avg_length = None
        self.std_length = None
        self.uniq_counts = None
        # Dict for the measurements, used in caching
        self.length_stats_dict = {}
        # Filenames, used in caching
        self.lengths_dir = "lengths"
        length_meas_json = "length_measurements.json"
        lengths_fig_png = "lengths_fig.png"
        lengths_df_json = "lengths_table.json"
        self.length_stats_json_fid = pjoin(self.cache_dir, self.lengths_dir, length_meas_json)
        self.lengths_fig_png_fid = pjoin(self.cache_dir, self.lengths_dir, lengths_fig_png)
        self.lengths_df_json_fid = pjoin(self.cache_dir, self.lengths_dir, lengths_df_json)

    def run_DMT_processing(self):
        """
        Gets data structures for the figure, table, and measurements.

        Loads from the cache first when ``use_cache`` is set. If no
        measurements were loaded and ``load_only`` is not set, computes them
        (and the table and figure) and, when ``save`` is set, writes them to
        the cache.
        """
        # First look to see what we can load from cache.
        if self.use_cache:
            logs.info("Trying to load from cache...")
            # Defines self.lengths_df, self.length_stats_dict, self.fig_lengths
            # This is the table, the dict of measurements, and the figure
            self.load_lengths_cache()
            # Sets the measurements as attributes of the DMT object
            self.set_attributes()
        # If we do not have measurements loaded from cache...
        if not self.length_stats_dict and not self.load_only:
            logs.info("Preparing length results")
            # Compute length statistics. Uses the Lengths class.
            self.lengths_obj = self._prepare_lengths()
            # Dict of measurements
            self.length_stats_dict = self.lengths_obj.length_stats_dict
            # Table of text and lengths
            self.lengths_df = self.lengths_obj.lengths_df
            # Sets the measurements as attributes of the DMT object
            self.set_attributes()
            # Makes the figure
            self.fig_lengths = make_fig_lengths(self.lengths_df)
            if self.save:
                logs.info("Saving results.")
                self._write_lengths_cache()
                # As soon as we have a figure, we redefine it as an image.
                # This is a hack to handle a UI display error (TODO: file bug)
                # Only swap in the PNG just written: when not saving, a PNG
                # on disk comes from an earlier run and may not match the
                # results computed here.
                if exists(self.lengths_fig_png_fid):
                    self.fig_lengths = self._open_image(self.lengths_fig_png_fid)

    def set_attributes(self):
        """Copy the measurements from ``length_stats_dict`` onto attributes."""
        if self.length_stats_dict:
            self.avg_length = self.length_stats_dict[AVG]
            self.std_length = self.length_stats_dict[STD]
            self.uniq_counts = self.length_stats_dict[UNIQ]
        else:
            logs.info("No lengths stats found. =(")

    @staticmethod
    def _open_image(image_fid):
        """Open an image file and read its pixels into memory immediately.

        ``Image.open`` is lazy and keeps the file open until the pixel data is
        loaded; loading eagerly lets Pillow release the handle right away.
        """
        image = Image.open(image_fid)
        image.load()
        return image

    def load_lengths_cache(self):
        """Load whichever of the table, figure, and measurements are cached."""
        # Dataframe with <sentence, length> exists. Load it.
        if exists(self.lengths_df_json_fid):
            self.lengths_df = ds_utils.read_df(self.lengths_df_json_fid)
        # Image exists. Load it.
        if exists(self.lengths_fig_png_fid):
            self.fig_lengths = self._open_image(self.lengths_fig_png_fid)
        # Measurements exist. Load them.
        if exists(self.length_stats_json_fid):
            self.length_stats_dict = ds_utils.read_json(self.length_stats_json_fid)

    def _write_lengths_cache(self):
        """Write each available result to its cache file."""
        ds_utils.make_path(pjoin(self.cache_dir, self.lengths_dir))
        if self.length_stats_dict:
            ds_utils.write_json(self.length_stats_dict, self.length_stats_json_fid)
        # Only a live matplotlib Figure can be saved; a figure already loaded
        # as an image from the cache is skipped.
        if isinstance(self.fig_lengths, Figure):
            self.fig_lengths.savefig(self.lengths_fig_png_fid)
        if isinstance(self.lengths_df, pd.DataFrame):
            ds_utils.write_df(self.lengths_df, self.lengths_df_json_fid)

    def _prepare_lengths(self):
        """Loads a Lengths object and computes length statistics"""
        # Length object for the dataset
        lengths_obj = Lengths(dataset=self.tokenized_df)
        lengths_obj.prepare_lengths()
        return lengths_obj

    def get_filenames(self):
        """Return the cache file paths keyed by a human-readable label."""
        lengths_fid_dict = {"statistics": self.length_stats_json_fid,
                            "figure png": self.lengths_fig_png_fid,
                            "table": self.lengths_df_json_fid}
        return lengths_fid_dict
class Lengths:
    """Generic class for text length processing.

    Uses DataFrames for faster processing.
    Given a dataframe with tokenized words in a column named TOKENIZED_FIELD
    and the text instances in a column named TEXT_FIELD, compute the average
    instance length, the standard deviation of the lengths, and the number of
    distinct lengths.
    """

    def __init__(self, dataset):
        """
        :param dataset: DataFrame with TEXT_FIELD and TOKENIZED_FIELD columns.
        """
        self.dset_df = dataset
        # Dict of measurements, keyed by AVG / STD / UNIQ
        self.length_stats_dict = {}
        # Measurements
        self.avg_length = None
        self.std_length = None
        self.num_uniq_lengths = None
        # Table of lengths and sentences
        self.lengths_df = None

    def prepare_lengths(self):
        """Build ``lengths_df`` and compute the length measurements.

        An instance's length is ``len()`` of its TOKENIZED_FIELD entry.

        :raises statistics.StatisticsError: if the dataset has no instances.
        """
        self.lengths_df = pd.DataFrame(self.dset_df[TEXT_FIELD])
        self.lengths_df[LENGTH_FIELD] = self.dset_df[TOKENIZED_FIELD].apply(len)
        lengths_array = self.lengths_df[LENGTH_FIELD]
        self.avg_length = statistics.mean(lengths_array)
        # statistics.stdev (sample std) needs at least two points; a
        # single-instance dataset has no spread, so report 0.0 instead of
        # raising.
        if len(lengths_array) > 1:
            self.std_length = statistics.stdev(lengths_array)
        else:
            self.std_length = 0.0
        self.num_uniq_lengths = lengths_array.nunique()
        # Key the dict with the module-level constants so it always matches
        # the keys used to read the measurements back.
        self.length_stats_dict = {
            AVG: self.avg_length,
            STD: self.std_length,
            UNIQ: self.num_uniq_lengths,
        }