# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import json
import math
import os
from os.path import exists
from os.path import join as pjoin

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import transformers
from datasets import load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

# prepare_clustering_dataset is provided by the companion dataset_utils module
# and is required by ClusteringBuilder.set_features_dataset below
from .dataset_utils import prepare_clustering_dataset

pd.options.display.max_colwidth = 256

_CACHE_DIR = "cache_dir"

_DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

_MAX_MERGE = 20000000  # to run on 64GB RAM laptop
def sentence_mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
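

# Example sketch (hypothetical sentences, using the default model named above):
#   tokenizer = transformers.AutoTokenizer.from_pretrained(_DEFAULT_MODEL)
#   model = transformers.AutoModel.from_pretrained(_DEFAULT_MODEL)
#   batch = tokenizer(["a first sentence", "a second one"], padding=True, return_tensors="pt")
#   with torch.no_grad():
#       embeds = sentence_mean_pooling(model(**batch), batch["attention_mask"])
#   # embeds has shape (2, hidden_size); ClusteringBuilder.compute_feature_embeddings
#   # additionally L2-normalizes these vectors before computing cosine similarities.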
# get nearest neighbors of a centroid by dot product
def get_examplars(example_ids, centroid, embeddings, dset, n_examplars):
    example_embeds = embeddings[example_ids]
    example_scores = torch.mv(example_embeds, centroid)
    s_scores, s_ids = example_scores.sort(dim=-1, descending=True)
    examplars = [
        (example_ids[i.item()], s.item())
        for i, s in zip(s_ids[:n_examplars], s_scores[:n_examplars])
    ]
    res = []
    for eid, score in examplars:
        dct = dict(dset[eid])
        dct["score"] = score
        res += [dct]
    return res
# order node children so that the large ones are in the middle
# makes visualization more balanced
def pretty_order(nodes, node_ids):
    sorted_ids = sorted(node_ids, key=lambda nid: nodes[nid]["weight"])
    sorted_a = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 0]
    sorted_b = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 1]
    sorted_b.reverse()
    return sorted_a + sorted_b
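

# Illustration (hypothetical weights): if the children have weights [1, 2, 3, 4, 5],
# the ascending sort gives [1, 2, 3, 4, 5]; keeping the even positions and then the
# reversed odd positions yields [1, 3, 5, 4, 2], so the heaviest child ends up in
# the middle of the row when the tree is drawn.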
def make_tree_plot(node_list, root_id, max_depth=-1):
    # make plot nodes
    plot_nodes = [{} for _ in node_list]

    root = {
        "parent_id": -1,
        "node_id": root_id,
        "label": node_list[root_id]["hover_text"],
        "weight": node_list[root_id]["weight"],
        "num_leaves": 0,
        "children_ids": node_list[root_id]["children_ids"],
        "Xmin": 0,
        "Y": 0,
    }
    plot_nodes[root_id] = root
    root_depth = node_list[root_id]["depth"]

    def rec_make_coordinates(node):
        total_weight = 0
        recurse = (max_depth == -1) or (
            node_list[node["node_id"]]["depth"] - root_depth < max_depth - 1
        )
        for cid in node["children_ids"]:
            plot_nodes[cid] = {
                "parent_id": node["node_id"],
                "node_id": cid,
                "label": node_list[cid]["hover_text"],
                "weight": node_list[cid]["weight"],
                "children_ids": node_list[cid]["children_ids"] if recurse else [],
                "Xmin": node["Xmin"] + total_weight,
                "Y": node["Y"] - 1,
            }
            plot_nodes[cid]["num_leaves"] = (
                1 if len(plot_nodes[cid]["children_ids"]) == 0 else 0
            )
            rec_make_coordinates(plot_nodes[cid])
            total_weight += plot_nodes[cid]["num_leaves"]
            node["num_leaves"] += plot_nodes[cid]["num_leaves"]
        node["Xmax"] = node["Xmin"] + node["num_leaves"]
        node["X"] = node["Xmin"] + (node["num_leaves"] / 2)

    rec_make_coordinates(root)
    subtree_nodes = [node for node in plot_nodes if len(node) > 0]
    nid_map = dict([(node["node_id"], nid) for nid, node in enumerate(subtree_nodes)])
    labels = [node["label"] for node in subtree_nodes]

    E = []  # list of edges
    Xn = []
    Yn = []
    Xe = []
    Ye = []
    for nid, node in enumerate(subtree_nodes):
        Xn += [node["X"]]
        Yn += [node["Y"]]
        for cid in node["children_ids"]:
            child = plot_nodes[cid]
            E += [(nid, nid_map[child["node_id"]])]
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]

    # make figure
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            name="",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1)
                # '#DB4551',
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    fig.layout.showlegend = False
    return fig
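

# Minimal input sketch (hypothetical nodes): each entry only needs "hover_text",
# "weight", "children_ids" and "depth"; the resulting figure can be rendered with
# fig.show() or embedded in a Space front-end.
#   demo_nodes = [
#       {"hover_text": "root", "weight": 2, "children_ids": [1, 2], "depth": 0},
#       {"hover_text": "left leaf", "weight": 1, "children_ids": [], "depth": 1},
#       {"hover_text": "right leaf", "weight": 1, "children_ids": [], "depth": 1},
#   ]
#   fig = make_tree_plot(demo_nodes, root_id=0)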
class ClusteringBuilder:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        model_name=_DEFAULT_MODEL,
    ):
        """Item embeddings and clustering"""
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.cache_path_list = [
            _CACHE_DIR,
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
        ]
        self.cache_path = pjoin(*self.cache_path_list)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

    # prepare embeddings for the dataset
    def set_model(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
            self.device
        )

    def set_features_dataset(self, use_streaming, use_auth_token, use_dataset):
        dset, dset_path = prepare_clustering_dataset(
            dataset_name=self.dataset_name,
            input_field_path=self.input_field_path,
            label_name=self.label_name,
            config_name=self.config_name,
            split_name=self.split_name,
            num_rows=self.num_rows,
            use_streaming=use_streaming,
            use_auth_token=use_auth_token,
            use_dataset=use_dataset,
        )
        self.features_dset = dset

    def compute_feature_embeddings(self, sentences):
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        with torch.no_grad():
            model_output = self.model(**batch)
            sentence_embeds = sentence_mean_pooling(
                model_output, batch["attention_mask"]
            )
            sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
        return sentence_embeds

    def set_embeddings_dataset(self):
        def batch_embed(examples):
            return {
                "embedding": [
                    embed.tolist()
                    for embed in self.compute_feature_embeddings(examples["field"])
                ]
            }

        if not exists(self.cache_path):
            os.mkdir(self.cache_path)

        self.embeddings_dset = self.features_dset.map(
            batch_embed,
            batched=True,
            batch_size=32,
            cache_file_name=pjoin(self.cache_path, "embeddings_dset"),
        )

    def prepare_embeddings(
        self,
        use_streaming=True,
        use_auth_token=None,
        use_dataset=None,
    ):
        self.set_model()
        self.set_features_dataset(use_streaming, use_auth_token, use_dataset)
        self.set_embeddings_dataset()
    # make cluster tree
    def prepare_merges(self, batch_size, low_thres):
        self.embeddings = torch.Tensor(self.embeddings_dset["embedding"])
        all_indices = torch.LongTensor(torch.Size([0, 2]))
        all_scores = torch.Tensor(torch.Size([0]))
        n_batches = math.ceil(self.embeddings_dset.num_rows / batch_size)
        for a in range(n_batches):
            for b in tqdm(range(a, n_batches)):
                cos_scores = torch.mm(
                    self.embeddings[a * batch_size : (a + 1) * batch_size],
                    self.embeddings[b * batch_size : (b + 1) * batch_size].t(),
                )
                if a == b:
                    cos_scores = cos_scores.triu(diagonal=1)
                merge_indices = torch.nonzero(cos_scores > low_thres)
                merge_indices[:, 0] += a * batch_size
                merge_indices[:, 1] += b * batch_size
                merge_scores = cos_scores[cos_scores > low_thres]
                all_indices = torch.cat([all_indices, merge_indices], dim=0)
                all_scores = torch.cat([all_scores, merge_scores], dim=0)
        self.sorted_scores, sorted_score_ids = all_scores.sort(dim=0, descending=True)
        self.sorted_scores = self.sorted_scores[:_MAX_MERGE]
        sorted_score_ids = sorted_score_ids[:_MAX_MERGE]
        self.sorted_indices = all_indices[sorted_score_ids]
    def make_starting_nodes(self, identical_threshold):
        identical_indices = self.sorted_indices[
            self.sorted_scores >= identical_threshold
        ]
        identical_inter = identical_indices[
            identical_indices[:, 1].sort(stable=True).indices
        ]
        identical_sorted = identical_inter[
            identical_inter[:, 0].sort(stable=True).indices
        ]
        self.parents = {}
        for a_pre, b_pre in identical_sorted:
            a = a_pre.item()
            b = b_pre.item()
            while self.parents.get(a, -1) != -1:
                a = self.parents[a]
            self.parents[b] = a
        self.duplicates = {}
        for a, b in self.parents.items():
            self.duplicates[b] = self.duplicates.get(b, []) + [a]
        self.nodes = {}
        for node_id in range(self.features_dset.num_rows):
            if node_id in self.parents:
                continue
            else:
                self.nodes[node_id] = {
                    "node_id": node_id,
                    "parent_id": -1,
                    "children": [],
                    "children_ids": [],
                    "example_ids": [node_id],
                    "weight": 1,
                    "merge_threshold": 0.98,
                    "depth": 0,
                }
    def make_merge_nodes(self, identical_threshold, thres_step):
        new_node_id = self.features_dset.num_rows
        current_thres = identical_threshold
        depth = 1
        merge_ids = self.sorted_indices[self.sorted_scores < identical_threshold]
        merge_scores = self.sorted_scores[self.sorted_scores < identical_threshold]
        for (node_id_a, node_id_b), merge_score in tqdm(
            zip(merge_ids, merge_scores), total=len(merge_ids)
        ):
            if merge_score.item() < current_thres:
                current_thres -= thres_step
            merge_a = node_id_a.item()
            while self.parents.get(merge_a, -1) != -1:
                merge_a = self.parents[merge_a]
            self.parents[node_id_a] = merge_a
            merge_b = node_id_b.item()
            while self.parents.get(merge_b, -1) != -1:
                merge_b = self.parents[merge_b]
            self.parents[node_id_b] = merge_b
            if merge_a == merge_b:
                continue
            else:
                merge_b, merge_a = sorted([merge_a, merge_b])
                node_a = self.nodes[merge_a]
                node_b = self.nodes[merge_b]
                if (node_a["depth"]) > 0 and min(
                    node_a["merge_threshold"], node_b["merge_threshold"]
                ) == current_thres:
                    node_a["depth"] = max(node_a["depth"], node_b["depth"])
                    node_a["weight"] += node_b["weight"]
                    node_a["children_ids"] += (
                        node_b["children_ids"]
                        if node_b["depth"] > 0
                        else [node_b["node_id"]]
                    )
                    for cid in node_b["children_ids"]:
                        self.nodes[cid]["parent_id"] = node_a["node_id"]
                        self.parents[cid] = node_a["node_id"]
                    node_b["parent_id"] = node_a["node_id"]
                    self.parents[node_b["node_id"]] = node_a["node_id"]
                else:
                    new_nid = new_node_id
                    new_node_id += 1
                    new_node = {
                        "node_id": new_nid,
                        "parent_id": -1,
                        "children_ids": [node_a["node_id"], node_b["node_id"]],
                        "example_ids": [],
                        "weight": node_a["weight"] + node_b["weight"],
                        "merge_threshold": current_thres,
                        "depth": max(node_a["depth"], node_b["depth"]) + 1,
                    }
                    depth = max(depth, new_node["depth"])
                    node_a["parent_id"] = new_nid
                    node_b["parent_id"] = new_nid
                    self.parents[node_a["node_id"]] = new_nid
                    self.parents[node_b["node_id"]] = new_nid
                    self.parents[node_id_a] = new_nid
                    self.parents[node_id_b] = new_nid
                    self.nodes[new_nid] = new_node
        return new_node_id
    def collapse_nodes(self, node, min_weight):
        children = [
            self.collapse_nodes(self.nodes[cid], min_weight)
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] >= min_weight
        ]
        extras = [
            lid
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] < min_weight
            for lid in self.collapse_nodes(self.nodes[cid], min_weight)["example_ids"]
        ] + node["example_ids"]
        extras_embed = (
            torch.cat(
                [self.embeddings[eid][None, :] for eid in extras],
                dim=0,
            ).sum(dim=0)
            if len(extras) > 0
            else torch.zeros(self.embeddings.shape[-1])
        )
        if len(children) == 0:
            node["extras"] = extras
            node["children_ids"] = []
            node["example_ids"] = extras
            node["embedding_sum"] = extras_embed
        elif len(children) == 1:
            node["extras"] = extras + children[0]["extras"]
            node["children_ids"] = children[0]["children_ids"]
            node["example_ids"] = extras + children[0]["example_ids"]
            node["embedding_sum"] = extras_embed + children[0]["embedding_sum"]
        else:
            node["extras"] = extras
            node["children_ids"] = [child["node_id"] for child in children]
            node["example_ids"] = extras + [
                eid for child in children for eid in child["example_ids"]
            ]
            node["embedding_sum"] = (
                extras_embed
                + torch.cat(
                    [child["embedding_sum"][None, :] for child in children],
                    dim=0,
                ).sum(dim=0)
            )
        assert (
            len(node["example_ids"]) == node["weight"]
        ), f"stuck at {node['node_id']} - {len(node['example_ids'])} - {node['weight']}"
        return node
    def finalize_node(self, node, parent_id, n_examplars, with_labels):
        new_node_id = len(self.tree_node_list)
        new_node = {
            "node_id": new_node_id,
            "parent_id": parent_id,
            "depth": 0
            if parent_id == -1
            else self.tree_node_list[parent_id]["depth"] + 1,
            "merged_at": node["merge_threshold"],
            "weight": node["weight"],
            "is_extra": False,
        }
        self.tree_node_list += [new_node]
        centroid = node["embedding_sum"] / node["embedding_sum"].norm()
        new_node["centroid"] = centroid.tolist()
        new_node["examplars"] = get_examplars(
            node["example_ids"],
            centroid,
            self.embeddings,
            self.features_dset,
            n_examplars,
        )
        label_counts = {}
        if with_labels:
            for eid in node["example_ids"]:
                label = self.features_dset[eid]["label"]
                label_counts[label] = label_counts.get(label, 0) + 1
        new_node["label_counts"] = sorted(
            label_counts.items(), key=lambda x: x[1], reverse=True
        )
        if len(node["children_ids"]) == 0:
            new_node["children_ids"] = []
        else:
            children = [
                self.nodes[cid]
                for cid in pretty_order(self.nodes, node["children_ids"])
            ]
            children_ids = [
                self.finalize_node(child, new_node_id, n_examplars, with_labels)
                for child in children
            ]
            new_node["children_ids"] = children_ids
        if len(node["extras"]) > 0:
            extra_node = {
                "node_id": len(self.tree_node_list),
                "parent_id": new_node_id,
                "depth": new_node["depth"] + 1,
                "merged_at": node["merge_threshold"],
                "weight": len(node["extras"]),
                "is_extra": True,
                "centroid": new_node["centroid"],
                "examplars": get_examplars(
                    node["extras"],
                    centroid,
                    self.embeddings,
                    self.features_dset,
                    n_examplars,
                ),
            }
            self.tree_node_list += [extra_node]
            label_counts = {}
            if with_labels:
                for eid in node["extras"]:
                    label = self.features_dset[eid]["label"]
                    label_counts[label] = label_counts.get(label, 0) + 1
            extra_node["label_counts"] = sorted(
                label_counts.items(), key=lambda x: x[1], reverse=True
            )
            extra_node["children_ids"] = []
            new_node["children_ids"] += [extra_node["node_id"]]
        return new_node_id
    def make_hover_text(self, num_examples=5, text_width=64, with_labels=False):
        for nid, node in enumerate(self.tree_node_list):
            line_list = [
                f"Node {nid:3d} - {node['weight']:6d} items - Linking threshold: {node['merged_at']:.2f}"
            ]
            for examplar in node["examplars"][:num_examples]:
                line_list += [
                    f"{examplar['ids']:6d}:{examplar['score']:.2f} - {examplar['field'][:text_width]}"
                    + (f" - {examplar['label']}" if with_labels else "")
                ]
            if with_labels:
                line_list += ["Label distribution"]
                for label, count in node["label_counts"]:
                    line_list += [f" - label: {label} - {count} items"]
            node["hover_text"] = "<br>".join(line_list)
    def build_tree(
        self,
        batch_size=10000,
        low_thres=0.5,
        identical_threshold=0.95,
        thres_step=0.05,
        min_weight=10,
        n_examplars=25,
        hover_examples=5,
        hover_text_width=64,
    ):
        self.prepare_merges(batch_size, low_thres)
        self.make_starting_nodes(identical_threshold)
        # make a root to join all trees
        root_node_id = self.make_merge_nodes(identical_threshold, thres_step)
        top_nodes = [node for node in self.nodes.values() if node["parent_id"] == -1]
        root_node = {
            "node_id": root_node_id,
            "parent_id": -1,
            "children_ids": [node["node_id"] for node in top_nodes],
            "example_ids": [],
            "weight": sum([node["weight"] for node in top_nodes]),
            "merge_threshold": -1.0,
            "depth": 1 + max([node["depth"] for node in top_nodes]),
        }
        for node in top_nodes:
            node["parent_id"] = root_node_id
        self.nodes[root_node_id] = root_node
        _ = self.collapse_nodes(root_node, min_weight)
        self.tree_node_list = []
        self.finalize_node(
            root_node,
            -1,
            n_examplars,
            with_labels=(self.label_name is not None),
        )
        self.make_hover_text(
            num_examples=hover_examples,
            text_width=hover_text_width,
            with_labels=(self.label_name is not None),
        )
    def push_to_hub(self, use_auth_token=None, file_name=None):
        path_list = self.cache_path_list
        name = "tree" if file_name is None else file_name
        tree_file = pjoin(pjoin(*path_list), f"{name}.jsonl.gz")
        fout = gzip.open(tree_file, "w")
        for node in tqdm(self.tree_node_list):
            _ = fout.write((json.dumps(node) + "\n").encode("utf-8"))
        fout.close()
        api = HfApi()
        file_loc = api.upload_file(
            path_or_fileobj=tree_file,
            path_in_repo=pjoin(pjoin(*path_list[1:]), f"{name}.jsonl.gz"),
            repo_id="yjernite/datasets_clusters",
            token=use_auth_token,
            repo_type="dataset",
        )
        return file_loc
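

# Usage sketch (hypothetical arguments; the dataset, config, and field names are
# only examples, and pushing requires write access to yjernite/datasets_clusters):
#   builder = ClusteringBuilder(
#       dataset_name="glue",
#       config_name="sst2",
#       split_name="train",
#       input_field_path=["sentence"],
#       label_name="label",
#       num_rows=10000,
#   )
#   builder.prepare_embeddings(use_streaming=True)
#   builder.build_tree()
#   builder.push_to_hub(use_auth_token="hf_...")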
class Clustering:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        n_examplars=10,
        model_name=_DEFAULT_MODEL,
        file_name=None,
        max_depth_subtree=3,
    ):
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.model_name = model_name
        self.n_examplars = n_examplars
        self.file_name = "tree" if file_name is None else file_name
        self.repo_path_list = [
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
            f"{self.file_name}.jsonl.gz",
        ]
        self.repo_path = pjoin(*self.repo_path_list)
        self.node_list = load_dataset(
            "yjernite/datasets_clusters", data_files=[self.repo_path]
        )["train"]
        self.node_reps = [{} for node in self.node_list]
        self.max_depth_subtree = max_depth_subtree
    def set_full_tree(self):
        self.node_reps[0]["tree"] = self.node_reps[0].get(
            "tree",
            make_tree_plot(
                self.node_list,
                0,
            ),
        )

    def get_full_tree(self):
        self.set_full_tree()
        return self.node_reps[0]["tree"]

    def set_node_subtree(self, node_id):
        self.node_reps[node_id]["subtree"] = self.node_reps[node_id].get(
            "subtree",
            make_tree_plot(
                self.node_list,
                node_id,
                self.max_depth_subtree,
            ),
        )

    def get_node_subtree(self, node_id):
        self.set_node_subtree(node_id)
        return self.node_reps[node_id]["subtree"]

    def set_node_examplars(self, node_id):
        self.node_reps[node_id]["examplars"] = self.node_reps[node_id].get(
            "examplars",
            pd.DataFrame(
                [
                    {
                        "id": exple["ids"],
                        "score": exple["score"],
                        "field": exple["field"],
                        "label": exple.get("label", "N/A"),
                    }
                    for exple in self.node_list[node_id]["examplars"]
                ][: self.n_examplars]
            ),
        )

    def get_node_examplars(self, node_id):
        self.set_node_examplars(node_id)
        return self.node_reps[node_id]["examplars"]

    def set_node_label_chart(self, node_id):
        self.node_reps[node_id]["label_chart"] = self.node_reps[node_id].get(
            "label_chart",
            px.pie(
                values=[ct for lab, ct in self.node_list[node_id]["label_counts"]],
                names=[
                    f"Label {lab}"
                    for lab, ct in self.node_list[node_id]["label_counts"]
                ],
                color_discrete_sequence=px.colors.sequential.Rainbow,
                width=400,
                height=400,
            ),
        )

    def get_node_label_chart(self, node_id):
        self.set_node_label_chart(node_id)
        return self.node_reps[node_id]["label_chart"]
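

# Usage sketch (hypothetical arguments; assumes a matching tree file was already
# pushed to the yjernite/datasets_clusters dataset repo by
# ClusteringBuilder.push_to_hub with the same dataset/config/split/field/rows/model):
#   clusters = Clustering(
#       dataset_name="glue",
#       config_name="sst2",
#       split_name="train",
#       input_field_path=["sentence"],
#       label_name="label",
#       num_rows=10000,
#   )
#   clusters.get_full_tree().show()            # plotly figure of the whole tree
#   clusters.get_node_examplars(node_id=1)     # pandas DataFrame of nearest examples
#   clusters.get_node_label_chart(node_id=1)   # pie chart of the label distribution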