# chat-gradio / personalized-chat-bot / models/personality_clustering.py
# Last commit 0766044 by j.gilyazev: "add personalized-chat-bot" (file size 3.04 kB)
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import pickle
class PersonalityClustering:
    """Cluster persona descriptions by sentence embedding.

    Each persona string is embedded with a SentenceTransformer model, and the
    embeddings are grouped with KMeans. For every cluster, the training
    persona whose embedding is closest (Euclidean distance) to the cluster
    centre is kept as that cluster's representative. New personas can then be
    mapped to the most similar representative seen during training.

    The embedding model and the KMeans estimator are both created lazily on
    first access, so constructing this object is cheap.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """
        Args:
            n_clusters: Number of KMeans clusters. If None, KMeans is built
                with scikit-learn's own default number of clusters.
            device: Device passed to SentenceTransformer (e.g. 'cpu').
            model_name: SentenceTransformer model name. Defaults to
                DEFAULT_SENTENCE_TRANSFORMER.
        """
        self.model_name = (
            self.DEFAULT_SENTENCE_TRANSFORMER if model_name is None else model_name
        )
        self.device = device
        self.n_clusters = n_clusters
        # Representative persona string per cluster id; set by fit() or load().
        self._cluster_centers = None
        self.__clustering = None
        self.__sentence_transformer = None

    @property
    def sentence_transformer(self):
        """Lazily create and cache the SentenceTransformer embedding model."""
        # `is None` rather than truthiness: the model is an nn.Sequential, so
        # `not model` would consult its __len__.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(self.model_name, device=self.device)
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Lazily create and cache the KMeans estimator."""
        if self.__clustering is None:
            if self.n_clusters is None:
                # KMeans rejects n_clusters=None; let scikit-learn use its default.
                self.__clustering = KMeans()
            else:
                self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def load(self, path):
        """Restore the KMeans estimator and cluster representatives from *path*.

        The file is expected to have been written by ``save``. It is read with
        ``pickle``, which can execute arbitrary code while loading, so only
        load files from a trusted source.
        """
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)
        # Keep n_clusters consistent with the restored estimator.
        if self.__clustering is not None:
            self.n_clusters = getattr(self.__clustering, "n_clusters", self.n_clusters)

    def save(self, path):
        """Pickle the KMeans estimator and cluster representatives to *path*.

        The sentence-transformer model is not written. After ``load`` it is
        re-created lazily from ``model_name``.
        """
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    @staticmethod
    def _nearest_to_centers(personalities, embeddings, labels, centers):
        """Return, for each centre, the persona closest to it within its cluster."""
        representatives = []
        for cluster_id, center in enumerate(centers):
            distances = np.linalg.norm(embeddings - center, axis=1)
            members = np.flatnonzero(labels == cluster_id)
            if members.size == 0:
                # Empty cluster: fall back to the closest persona overall
                # instead of storing None as the representative.
                members = np.arange(len(personalities))
            # np.argmin returns the first minimum, i.e. ties go to the earliest persona.
            representatives.append(personalities[members[np.argmin(distances[members])]])
        return np.array(representatives)

    def _fit_labels(self, personalities):
        """Fit KMeans on the personas' embeddings and return each persona's label."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities.tolist())
        labels = self.clustering.fit_predict(embeddings)
        self._cluster_centers = self._nearest_to_centers(
            personalities, embeddings, labels, self.clustering.cluster_centers_
        )
        return labels

    def fit(self, personalities):
        """Fit the clustering on an iterable of persona strings.

        Returns:
            self, to allow chaining.
        """
        self._fit_labels(personalities)
        return self

    def predict(self, personalities):
        """Return the KMeans cluster id for each persona string."""
        embeddings = self.sentence_transformer.encode(list(personalities))
        return self.clustering.predict(embeddings)

    def predict_nearest_personality(self, personalities):
        """Map each persona to its cluster's representative training persona."""
        clusters = self.predict(personalities)
        return np.asarray(self._cluster_centers)[clusters]

    def fit_predict(self, personalities):
        """Fit on *personalities* and return their cluster labels.

        This is equivalent to calling ``fit`` and then ``predict`` on the same
        data. The personas are materialised and embedded only once, so
        one-shot iterators work and the transformer is not run twice.
        """
        return self._fit_labels(personalities)