# NOTE: page-scrape residue ("Spaces: Runtime error") -- not part of the module.
import pickle

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
class PersonalityClustering:
    """Cluster personality texts by the similarity of their sentence embeddings.

    Texts are embedded with a ``SentenceTransformer`` model and grouped with
    ``KMeans``.  After fitting, every cluster is represented by the training
    personality whose embedding is closest to that cluster's centroid.

    The embedding model and the KMeans estimator are only created on first
    access (see the ``sentence_transformer`` and ``clustering`` properties);
    ``__init__`` itself does not instantiate either of them.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """Store the configuration; no model is built here.

        Args:
            n_clusters: Number of clusters passed to ``KMeans`` when the
                estimator is lazily created.
            device: Device string forwarded to ``SentenceTransformer``.
            model_name: Name of the sentence-transformer model; falls back to
                ``DEFAULT_SENTENCE_TRANSFORMER`` when ``None``.
        """
        if model_name is None:
            self.model_name = self.DEFAULT_SENTENCE_TRANSFORMER
        else:
            self.model_name = model_name
        self.device = device
        self.n_clusters = n_clusters
        # Representative personality per cluster; filled in by fit()/load().
        self._cluster_centers = None
        self.__clustering = None
        self.__sentence_transformer = None

    @property
    def sentence_transformer(self):
        """Lazily initialised sentence transformer."""
        # Compare against None explicitly: truthiness of a model object may
        # depend on __len__ and is not a reliable "is it created" check.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(
                self.model_name, device=self.device
            )
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Lazily initialised KMeans estimator."""
        if self.__clustering is None:
            self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def load(self, path):
        """Restore the estimator and cluster representatives saved by ``save``.

        Args:
            path: File previously written by ``save``.
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load files that come from a trusted source.
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)
        # Keep the public n_clusters attribute in sync with the restored
        # estimator when it exposes one.
        loaded_n_clusters = getattr(self.__clustering, 'n_clusters', None)
        if loaded_n_clusters is not None:
            self.n_clusters = loaded_n_clusters

    def save(self, path):
        """Pickle the estimator and the cluster representatives to ``path``."""
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    def _fit_on_embeddings(self, personalities, embeddings):
        """Fit the estimator on ``embeddings`` and pick one persona per cluster.

        ``personalities`` and ``embeddings`` must be aligned arrays (row ``i``
        of ``embeddings`` is the embedding of ``personalities[i]``).
        """
        labels = np.asarray(self.clustering.fit_predict(embeddings))
        representatives = []
        for cluster_id, center in enumerate(self.clustering.cluster_centers_):
            members = labels == cluster_id
            if not members.any():
                # No training sample landed in this cluster: keep a None
                # placeholder so indices still line up with cluster ids.
                representatives.append(None)
                continue
            distances = np.linalg.norm(embeddings[members] - center, axis=1)
            # argmin returns the first minimum, i.e. ties go to the earliest
            # personality in input order.
            representatives.append(personalities[members][np.argmin(distances)])
        self._cluster_centers = np.array(representatives)

    def fit(self, personalities):
        """Fit the clustering on an iterable of personality texts.

        Args:
            personalities: Iterable of texts; it is materialised once.

        Returns:
            ``self``, to allow call chaining.
        """
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        self._fit_on_embeddings(personalities, embeddings)
        return self

    def predict(self, personalities):
        """Return the cluster index of each personality text."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        return self.clustering.predict(embeddings)

    def predict_nearest_personality(self, personalities):
        """Return, for each text, the representative persona of its cluster."""
        clusters = self.predict(personalities)
        return np.array([self._cluster_centers[clust] for clust in clusters])

    def fit_predict(self, personalities):
        """Fit on ``personalities`` and return their cluster indices.

        The input is materialised and embedded exactly once, so one-shot
        iterables (e.g. generators) work and the corpus is not encoded twice.
        """
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        self._fit_on_embeddings(personalities, embeddings)
        return self.clustering.predict(embeddings)