meg-huggingface committed
Commit • d8ab532
1 Parent(s): 7c5b4e0
Continuing cache minimizing in new repository. Please see https://github.com/huggingface/DataMeasurements for full history
data_measurements/dataset_statistics.py
CHANGED
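The diff below threads one caching pattern through the load_or_prepare_* methods: figures are serialized to JSON in the cache directory via new write_plotly/read_plotly helpers, read back when use_cache is set, and otherwise recomputed and optionally written out. The following sketch only distills that pattern; load_or_prepare_fig and the toy histogram are hypothetical and not part of the commit, while the two helpers mirror the ones added in the diff.

import json
from os.path import exists

import plotly
import plotly.express as px


def write_plotly(fig, fid):
    # Mirrors the helper added in this commit: figure -> JSON string -> file.
    with open(fid, "w", encoding="utf-8") as f:
        json.dump(plotly.io.to_json(fig), f)


def read_plotly(fid):
    # Mirrors the helper added in this commit: file -> JSON string -> figure.
    return plotly.io.from_json(json.load(open(fid, encoding="utf-8")))


def load_or_prepare_fig(fig_fid, build_fig, use_cache=False, save=True):
    # Hypothetical distillation of the load_or_prepare_* methods below:
    # use the cached figure when allowed, otherwise build it and optionally cache it.
    if use_cache and exists(fig_fid):
        return read_plotly(fig_fid)
    fig = build_fig()
    if save:
        write_plotly(fig, fig_fid)
    return fig


# Toy usage, standing in for load_or_prepare_text_lengths caching fig_tok_length.
fig = load_or_prepare_fig(
    "/tmp/fig_tok_length.json",
    lambda: px.histogram(x=[3, 5, 5, 8, 13]),
    use_cache=True,
)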
@@ -15,14 +15,15 @@
 import json
 import logging
 import statistics
+import torch
 from os import mkdir
 from os.path import exists, isdir
 from os.path import join as pjoin
-from pathlib import Path
 
 import nltk
 import numpy as np
 import pandas as pd
+import plotly
 import plotly.express as px
 import plotly.figure_factory as ff
 import plotly.graph_objects as go
@@ -59,8 +60,6 @@ logs.propagate = False
 
 if not logs.handlers:
 
-    Path('./log_files').mkdir(exist_ok=True)
-
     # Logging info to log file
     file = logging.FileHandler("./log_files/dataset_statistics.log")
     fileformat = logging.Formatter("%(asctime)s:%(message)s")
@@ -263,7 +262,12 @@ class DatasetStatisticsCacheClass:
         self.text_duplicate_counts_df_fid = pjoin(
             self.cache_path, "text_dup_counts_df.feather"
         )
+        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
+        self.fig_labels_fid = pjoin(self.cache_path, "fig_labels.json")
+        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
+        self.fig_tree_fid = pjoin(self.cache_path, "fig_tree.json")
         self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
+        self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
 
     def get_base_dataset(self):
         """Gets a pointer to the truncated base dataset object."""
@@ -307,7 +311,11 @@ class DatasetStatisticsCacheClass:
         write_df(self.text_dup_counts_df, self.text_duplicate_counts_df_fid)
         write_json(self.general_stats_dict, self.general_stats_fid)
 
-    def load_or_prepare_text_lengths(self, use_cache=False):
+    def load_or_prepare_text_lengths(self, use_cache=False, save=True):
+        # TODO: Everything here can be read from cache; it's in a transitory
+        # state atm where just the fig is cached. Clean up.
+        if use_cache and exists(self.fig_tok_length_fid):
+            self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         if len(self.tokenized_df) == 0:
             self.tokenized_df = self.do_tokenization()
         self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[TOKENIZED_FIELD].apply(len)
@@ -320,12 +328,28 @@ class DatasetStatisticsCacheClass:
             statistics.stdev(self.tokenized_df[self.our_length_field]), 1
         )
         self.fig_tok_length = make_fig_lengths(self.tokenized_df, self.our_length_field)
-
-
-
-
-        self.
-
+        if save:
+            write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
+
+    def load_or_prepare_embeddings(self, use_cache=False, save=True):
+        if use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
+            self.node_list = torch.load(self.node_list_fid)
+            self.fig_tree = read_plotly(self.fig_tree_fid)
+        elif use_cache and exists(self.node_list_fid):
+            self.node_list = torch.load(self.node_list_fid)
+            self.fig_tree = make_tree_plot(self.node_list,
+                                           self.text_dset)
+            if save:
+                write_plotly(self.fig_tree, self.fig_tree_fid)
+        else:
+            self.embeddings = Embeddings(self, use_cache=use_cache)
+            self.embeddings.make_hierarchical_clustering()
+            self.node_list = self.embeddings.node_list
+            self.fig_tree = make_tree_plot(self.node_list,
+                                           self.text_dset)
+            if save:
+                torch.save(self.node_list, self.node_list_fid)
+                write_plotly(self.fig_tree, self.fig_tree_fid)
 
     # get vocab with word counts
     def load_or_prepare_vocab(self, use_cache=True, save=True):
@@ -341,7 +365,7 @@ class DatasetStatisticsCacheClass:
         ):
             logs.info("Reading vocab from cache")
             self.load_vocab()
-            self.vocab_counts_filtered_df =
+            self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
         else:
             logs.info("Calculating vocab afresh")
             if len(self.tokenized_df) == 0:
@@ -352,7 +376,7 @@ class DatasetStatisticsCacheClass:
             word_count_df = count_vocab_frequencies(self.tokenized_df)
             logs.info("Making dfs with proportion.")
            self.vocab_counts_df = calc_p_word(word_count_df)
-            self.vocab_counts_filtered_df =
+            self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
         if save:
             logs.info("Writing out.")
             write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
@@ -365,17 +389,31 @@ class DatasetStatisticsCacheClass:
         self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=use_cache)
         self.npmi_stats.load_or_prepare_npmi_terms()
 
-    def load_or_prepare_zipf(self, use_cache=False):
-
+    def load_or_prepare_zipf(self, use_cache=False, save=True):
+        # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
+        # when only reading from cache. Either the UI should use it, or it should
+        # be removed when reading in cache
+        if use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
+            with open(self.zipf_fid, "r") as f:
+                zipf_dict = json.load(f)
+            self.z = Zipf()
+            self.z.load(zipf_dict)
+            self.zipf_fig = read_plotly(self.zipf_fig_fid)
+        elif use_cache and exists(self.zipf_fid):
             # TODO: Read zipf data so that the vocab is there.
             with open(self.zipf_fid, "r") as f:
                 zipf_dict = json.load(f)
             self.z = Zipf()
             self.z.load(zipf_dict)
+            self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
+            if save:
+                write_plotly(self.zipf_fig, self.zipf_fig_fid)
         else:
             self.z = Zipf(self.vocab_counts_df)
-
-
+            self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
+            if save:
+                write_zipf_data(self.z, self.zipf_fid)
+                write_plotly(self.zipf_fig, self.zipf_fig_fid)
 
     def prepare_general_text_stats(self):
         text_nan_count = int(self.tokenized_df.isnull().sum().sum())
@@ -476,6 +514,8 @@ class DatasetStatisticsCacheClass:
         self.label_field = label_field
 
     def load_or_prepare_labels(self, use_cache=False, save=True):
+        # TODO: This is in a transitory state for creating fig cache.
+        # Clean up to be caching and reading everything correctly.
         """
         Extracts labels from the Dataset
         :param use_cache:
@@ -483,9 +523,17 @@
         """
         # extracted labels
         if len(self.label_field) > 0:
-            if use_cache and exists(self.
+            if use_cache and exists(self.fig_labels_fid):
+                self.fig_labels = read_plotly(self.fig_labels_fid)
+            elif use_cache and exists(self.label_dset_fid):
                 # load extracted labels
                 self.label_dset = load_from_disk(self.label_dset_fid)
+                self.label_df = self.label_dset.to_pandas()
+                self.fig_labels = make_fig_labels(
+                    self.label_df, self.label_names, OUR_LABEL_FIELD
+                )
+                if save:
+                    write_plotly(self.fig_labels, self.fig_labels_fid)
             else:
                 self.get_base_dataset()
                 self.label_dset = self.dset.map(
@@ -495,14 +543,14 @@
                     batched=True,
                     remove_columns=list(self.dset.features),
                 )
+                self.label_df = self.label_dset.to_pandas()
+                self.fig_labels = make_fig_labels(
+                    self.label_df, self.label_names, OUR_LABEL_FIELD
+                )
                 if save:
                     # save extracted label instances
                     self.label_dset.save_to_disk(self.label_dset_fid)
-
-
-            self.fig_labels = make_fig_labels(
-                self.label_df, self.label_names, OUR_LABEL_FIELD
-            )
+                    write_plotly(self.fig_labels, self.fig_labels_fid)
 
     def load_vocab(self):
         with open(self.vocab_counts_df_fid, "rb") as f:
@@ -796,7 +844,7 @@ def calc_p_word(word_count_df):
    return vocab_counts_df
 
 
-def filter_words(vocab_counts_df):
+def filter_vocab(vocab_counts_df):
    # TODO: Add warnings (which words are missing) to log file?
    filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
                                                    errors="ignore")
@@ -808,6 +856,12 @@ def filter_words(vocab_counts_df):
 
 ## Figures ##
 
+def write_plotly(fig, fid):
+    write_json(plotly.io.to_json(fig), fid)
+
+def read_plotly(fid):
+    fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
+    return fig
 
 def make_fig_lengths(tokenized_df, length_field):
     fig_tok_length = px.histogram(
@@ -815,7 +869,6 @@ def make_fig_lengths(tokenized_df, length_field):
     )
     return fig_tok_length
 
-
 def make_fig_labels(label_df, label_names, label_field):
     labels = label_df[label_field].unique()
     label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
@@ -896,6 +949,89 @@ def make_zipf_fig(vocab_counts_df, z):
     return fig
 
 
+def make_tree_plot(node_list, text_dset):
+    nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
+
+    for nid, node in enumerate(node_list):
+        node["label"] = node.get(
+            "label",
+            f"{nid:2d} - {node['weight']:5d} items <br>"
+            + "<br>".join(
+                [
+                    "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
+                    for txt in list(
+                        set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
+                    )[:5]
+                ]
+            ),
+        )
+
+    # make plot nodes
+    # TODO: something more efficient than set to remove duplicates
+    labels = [node["label"] for node in node_list]
+
+    root = node_list[0]
+    root["X"] = 0
+    root["Y"] = 0
+
+    def rec_make_coordinates(node):
+        total_weight = 0
+        add_weight = len(node["example_ids"]) - sum(
+            [child["weight"] for child in node["children"]]
+        )
+        for child in node["children"]:
+            child["X"] = node["X"] + total_weight
+            child["Y"] = node["Y"] - 1
+            total_weight += child["weight"] + add_weight / len(node["children"])
+            rec_make_coordinates(child)
+
+    rec_make_coordinates(root)
+
+    E = []  # list of edges
+    Xn = []
+    Yn = []
+    Xe = []
+    Ye = []
+    for nid, node in enumerate(node_list):
+        Xn += [node["X"]]
+        Yn += [node["Y"]]
+        for child in node["children"]:
+            E += [(nid, nid_map[child["nid"]])]
+            Xe += [node["X"], child["X"], None]
+            Ye += [node["Y"], child["Y"], None]
+
+    # make figure
+    fig = go.Figure()
+    fig.add_trace(
+        go.Scatter(
+            x=Xe,
+            y=Ye,
+            mode="lines",
+            line=dict(color="rgb(210,210,210)", width=1),
+            hoverinfo="none",
+        )
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=Xn,
+            y=Yn,
+            mode="markers",
+            name="nodes",
+            marker=dict(
+                symbol="circle-dot",
+                size=18,
+                color="#6175c1",
+                line=dict(color="rgb(50,50,50)", width=1)
+                # '#DB4551',
+            ),
+            text=labels,
+            hoverinfo="text",
+            opacity=0.8,
+        )
+    )
+    return fig
+
+
 ## Input/Output ###
 
 
@@ -949,7 +1085,6 @@ def write_json(json_dict, json_fid):
     with open(json_fid, "w", encoding="utf-8") as f:
         json.dump(json_dict, f)
 
-
 def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     """
     Saves the calculated nPMI statistics to their output files.
@@ -969,7 +1104,6 @@ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     with open(subgroup_cooc_fid, "w+") as f:
         subgroup_cooc_df.to_csv(f)
 
-
 def write_zipf_data(z, zipf_fid):
     zipf_dict = {}
     zipf_dict["xmin"] = int(z.xmin)
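The hierarchical-clustering output is cached differently: node_list is written with torch.save to node_list.th and read back with torch.load, while the tree figure built from it by make_tree_plot goes through the same Plotly JSON path. A minimal round-trip sketch with a hypothetical single-node list (the keys follow the ones make_tree_plot reads above: nid, weight, example_ids, children); real lists come from Embeddings.make_hierarchical_clustering().

import torch

# Hypothetical stand-in for self.node_list, not data from the commit.
node_list = [
    {"nid": 0, "weight": 4, "example_ids": [0, 1, 2, 3], "children": []},
]

# Same calls the commit uses for the node_list cache (plain containers pickle fine).
torch.save(node_list, "node_list.th")
restored = torch.load("node_list.th")
assert restored[0]["example_ids"] == [0, 1, 2, 3]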