Spaces:
Runtime error
Runtime error
File size: 6,903 Bytes
0c131af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
from scipy.spatial.distance import cosine
import argparse
import json
import pdb
import torch
import torch.nn.functional as F
import numpy as np
import time
from collections import OrderedDict
class TWCClustering:
    """Z-score threshold clustering over a cosine-similarity matrix.

    A z-score ``threshold`` is converted into a raw cosine cutoff
    ``mean + threshold * std``, where ``mean``/``std`` are taken over every
    entry of the pairwise similarity matrix (diagonal included). Items whose
    similarity reaches the cutoff are grouped; see :meth:`cluster` for the
    overlapped vs. non-overlapped modes.
    """

    def __init__(self):
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix of ``embeddings``.

        Args:
            embeddings: array-like where each row (last axis) is one item's
                vector; presumably shape (n, d) -- confirm with callers.

        Returns:
            ``np.ndarray`` whose ``[i][j]`` entry is the cosine similarity of
            rows ``i`` and ``j``.
        """
        vectors = np.array(embeddings)
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        # An all-zero row has no direction. Dividing by its zero norm used to
        # fill the matrix with NaN, which then poisoned mean/std for every
        # item. Leaving it as a zero vector makes it simply dissimilar (0.0).
        norms = np.where(norms == 0, 1, norms)
        unit = vectors / norms
        return np.inner(unit, unit)

    def get_terms_above_threshold(self, matrix, embeddings, pivot_index, threshold):
        """Return indices ``j >= pivot_index`` with ``matrix[pivot_index][j] >= threshold``.

        Only the part of the pivot's row from ``pivot_index`` up to
        ``len(embeddings)`` is scanned, so the result is in ascending order.

        Returns:
            list of ``int`` indices.
        """
        row = np.asarray(matrix[pivot_index])[pivot_index:len(embeddings)]
        return (np.flatnonzero(row >= threshold) + pivot_index).tolist()

    def update_picked_dict_arr(self, picked_dict, arr):
        """Mark every index in ``arr`` as picked (value 1) in ``picked_dict``."""
        picked_dict.update(dict.fromkeys(arr, 1))

    def update_picked_dict(self, picked_dict, in_dict):
        """Mark every key of ``in_dict`` as picked (value 1) in ``picked_dict``."""
        picked_dict.update(dict.fromkeys(in_dict, 1))

    def find_pivot_subgraph(self, pivot_index, arr, matrix, threshold, strict_cluster=True):
        """Pick the most central member of ``arr`` and collect its neighbours.

        For each node in ``arr`` the similarities to all members of ``arr``
        are summed; with ``strict_cluster`` only similarities ``>= threshold``
        count. The node with the largest positive sum becomes the centre
        (ties keep the earliest node). If no sum is positive, ``pivot_index``
        stays the centre and it gets no neighbours.

        Returns:
            dict with ``"pivot_index"`` (chosen centre), ``"orig_index"``
            (the ``pivot_index`` argument) and ``"neighs"``: an
            ``OrderedDict`` mapping member index -> similarity to the centre,
            sorted by similarity, highest first.
        """
        center_index = pivot_index
        center_score = 0
        center_dict = {}
        for node_i in arr:
            running_score = 0
            temp_dict = {}
            for node_j in arr:
                similarity = matrix[node_i][node_j]
                if strict_cluster and similarity < threshold:
                    continue
                running_score += similarity
                temp_dict[node_j] = similarity
            if running_score > center_score:
                center_index = node_i
                center_dict = temp_dict
                center_score = running_score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def update_overlap_stats(self, overlap_dict, cluster_info):
        """Increment, per neighbour of ``cluster_info``, how many clusters it appears in."""
        for val in cluster_info["neighs"]:
            overlap_dict[val] = overlap_dict.get(val, 0) + 1

    def bucket_overlap(self, overlap_dict):
        """Histogram membership counts: count -> number of items with that count.

        Returns:
            ``OrderedDict`` sorted by number of items, ascending.
        """
        bucket_dict = {}
        for count in overlap_dict.values():
            bucket_dict[count] = bucket_dict.get(count, 0) + 1
        return OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1]))

    def merge_clusters(self, ref_cluster, curr_cluster):
        """Append to ``ref_cluster`` (in place) the items of ``curr_cluster`` it lacks."""
        existing = set(ref_cluster)
        ref_cluster.extend([idx for idx in curr_cluster if idx not in existing])

    def _merge_overlapping_candidates(self, candidates):
        """Merge, in place, candidate groups that share an index until all are disjoint."""
        i = 0
        while i < len(candidates):
            ref_cluster = candidates[i]
            ref_members = set(ref_cluster)
            for j in range(i + 1, len(candidates)):
                if not ref_members.isdisjoint(candidates[j]):
                    self.merge_clusters(ref_cluster, candidates.pop(j))
                    # Re-examine the grown group at the same position. Every
                    # earlier group is already disjoint from both halves of
                    # the union, so restarting from 0 could change nothing.
                    break
            else:
                i += 1

    def non_overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Append disjoint clusters to ``cluster_dict["clusters"]``.

        Each not-yet-covered item ``i`` seeds a candidate group with itself
        and the items ``j >= i`` at or above the cutoff. Groups that share
        any item are merged. Each merged group is then summarized by
        :meth:`find_pivot_subgraph` (non-strict, seeded at its first member).

        Returns:
            An empty dict -- overlap statistics do not apply to this mode.
        """
        zscore = mean + threshold * std  # cosine cutoff, loop-invariant
        picked_dict = {}
        candidates = []
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            # An item always belongs with itself. If float rounding (or a
            # cutoff above 1.0) puts its self-similarity under the cutoff, it
            # could otherwise be left out of every cluster -- and an empty
            # group made ``arr[0]`` below raise IndexError.
            if not arr or arr[0] != i:
                arr.insert(0, i)
            candidates.append(arr)
            self.update_picked_dict_arr(picked_dict, arr)
        self._merge_overlapping_candidates(candidates)
        for arr in candidates:
            cluster_info = self.find_pivot_subgraph(arr[0], arr, matrix, zscore, strict_cluster=False)
            cluster_dict["clusters"].append(cluster_info)
        return {}

    def overlapped_clustering(self, matrix, embeddings, threshold, mean, std, cluster_dict):
        """Append possibly-overlapping clusters to ``cluster_dict["clusters"]``.

        Each item not yet listed as a neighbour seeds a group of items
        ``j >= i`` at or above the cutoff. The group is summarized by
        :meth:`find_pivot_subgraph` (strict), and the chosen neighbours are
        marked as covered.

        Returns:
            ``OrderedDict`` from :meth:`bucket_overlap`.
        """
        picked_dict = {}
        overlap_dict = {}
        zscore = mean + threshold * std
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            cluster_info = self.find_pivot_subgraph(i, arr, matrix, zscore, strict_cluster=True)
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            self.update_overlap_stats(overlap_dict, cluster_info)
            cluster_dict["clusters"].append(cluster_info)
        return self.bucket_overlap(overlap_dict)

    def cluster(self, output_file, texts, embeddings, threshold, clustering_type):
        """Cluster ``embeddings`` and return clusters plus threshold statistics.

        Args:
            output_file: not used by this method.
            texts: not used by this method.
            embeddings: item vectors, one per row (see :meth:`compute_matrix`).
            threshold: z-score; the cosine cutoff is ``mean + threshold * std``.
            clustering_type: ``"overlapped"`` for overlapping clusters; any
                other value yields disjoint clusters.

        Returns:
            dict with ``"clusters"`` (dicts from :meth:`find_pivot_subgraph`)
            and ``"info"``. ``"info"`` holds:

            * ``mean`` / ``std`` of the similarity matrix;
            * ``current_threshold``, a printable string;
            * ``zscores``, a table mapping integer z-scores to cosine cutoffs
              below 1;
            * ``overlap``, a list of ``(membership count, number of items)``
              pairs, empty in non-overlapped mode.
        """
        is_overlapped = clustering_type == "overlapped"
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)
        zscores = []
        inc = 0
        value = mean
        while value < 1:
            zscores.append({"threshold": inc, "cosine": round(value, 2)})
            if not std > 0:
                # No spread (e.g. empty input or a single embedding whose
                # self-similarity rounds below 1.0): every z-score maps to
                # the same cosine, so ``value`` would never reach 1.
                break
            inc += 1
            value = mean + inc * std
        cluster_dict = {"clusters": []}
        if is_overlapped:
            sorted_d = self.overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        else:
            sorted_d = self.non_overlapped_clustering(matrix, embeddings, threshold, mean, std, cluster_dict)
        curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
        cluster_dict["info"] = {
            "mean": mean,
            "std": std,
            "current_threshold": curr_threshold,
            "zscores": zscores,
            "overlap": list(sorted_d.items()),
        }
        return cluster_dict
|