Spaces:
Runtime error
Runtime error
Merge branch 'bloom-personachat' of https://huggingface.co/spaces/hivemind-personalized-chat/chat-gradio
Browse files- personalized-chat-bot/bot_example.py +60 -0
- personalized-chat-bot/data.zip +3 -0
- personalized-chat-bot/generation_config.json +1 -0
- personalized-chat-bot/models/__init__.py +1 -0
- personalized-chat-bot/models/personality_clustering.py +74 -0
- personalized-chat-bot/personalized_chat_bot.py +65 -0
- personalized-chat-bot/prompt_paths.json +16 -0
- personalized-chat-bot/scripts/__init__.py +1 -0
- personalized-chat-bot/scripts/config_176b.json +16 -0
- personalized-chat-bot/scripts/config_6b.json +16 -0
- personalized-chat-bot/scripts/fit_personality_clustering.py +52 -0
- personalized-chat-bot/scripts/train_all.sh +11 -0
- personalized-chat-bot/scripts/train_bloom_personachat.py +123 -0
- personalized-chat-bot/util/__init__.py +0 -0
- personalized-chat-bot/util/bloom_trainer.py +91 -0
- personalized-chat-bot/util/data.py +74 -0
- personalized-chat-bot/util/dialogue_manager.py +27 -0
- personalized-chat-bot/util/metrics.py +27 -0
personalized-chat-bot/bot_example.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
|
6 |
+
from petals.client.remote_model import DistributedBloomForCausalLM
|
7 |
+
|
8 |
+
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
|
9 |
+
from models.personality_clustering import PersonalityClustering
|
10 |
+
|
11 |
+
def load_config(path):
    """Load a JSON configuration file into an attribute-style namespace.

    Args:
        path: Path to a JSON file whose top-level value is an object.

    Returns:
        An ``argparse.Namespace`` exposing every top-level key as an attribute.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
|
15 |
+
|
16 |
+
|
17 |
+
def main():
    """Run an interactive console chat with a persona-conditioned BLOOM model.

    Reads a persona description from stdin, selects the closest cluster-centre
    persona that has a trained soft prompt, loads the distributed model and
    tokenizer described by ``./scripts/config_176b.json``, and then answers
    user input in a loop until interrupted with Ctrl+C.
    """
    print('Describe the person you want to talk:')
    persona_description = input()
    print('Cool! wait a few seconds...')

    clustering = PersonalityClustering()
    clustering.load('./data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')

    def keys_to_int(dct):
        # JSON object keys are always strings; the prompt table is looked up
        # with integer persona ids, so convert them while decoding.
        return {int(key): value for key, value in dct.items()}

    with open('prompt_paths.json', 'r') as f:
        prompt_paths = json.load(f, object_hook=keys_to_int)

    manager = PersonalityManager(prompt_paths, clustering)
    prompt_path, closest_persona = manager.get_prompt(persona_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')

    config = load_config('./scripts/config_176b.json')
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
    )
    model = model.to(config.DEVICE)

    generation_config = load_config('generation_config.json')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt(prompt_path)
    print('Done! You can start a dialogue.')

    try:
        while True:
            user_text = input('You: ')
            reply = chatbot.answer(user_text)
            print(f'Bloom: {reply}')
    except KeyboardInterrupt:
        # Ctrl+C is the documented way to leave the chat loop.
        print('Thank you for the conversation!')
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
main()
|
personalized-chat-bot/data.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d73016d5eccc0eeb641f623789e6a80c601572aee825603bdfacf84c9e8f705
|
3 |
+
size 12635714
|
personalized-chat-bot/generation_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"MAX_TOKENS": 16, "TOP_K": 100, "TEMPERATURE": 0.8}
|
personalized-chat-bot/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# coding=utf-8
|
personalized-chat-bot/models/personality_clustering.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from sentence_transformers import SentenceTransformer
|
3 |
+
from sklearn.cluster import KMeans
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
|
7 |
+
class PersonalityClustering:
    """Cluster persona sentences with k-means over sentence embeddings.

    After fitting, every cluster is represented by the training sentence whose
    embedding lies closest to the k-means centroid; these representatives are
    stored in ``_cluster_centers`` and indexed by cluster id.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """Store the configuration; heavy objects are created on first use.

        Args:
            n_clusters: Number of k-means clusters (needed before ``fit``).
            device: Device string passed to the sentence encoder.
            model_name: Sentence-transformer model name; ``None`` selects
                ``DEFAULT_SENTENCE_TRANSFORMER``.
        """
        if model_name is None:
            model_name = self.DEFAULT_SENTENCE_TRANSFORMER
        self.model_name = model_name
        self.device = device
        self.n_clusters = n_clusters
        self._cluster_centers = None
        self.__clustering = None
        self.__sentence_transformer = None

    @property
    def sentence_transformer(self):
        """Lazily create and return the sentence encoder."""
        # Compare against None explicitly: truthiness would call __len__ on
        # container-like model objects, and an "empty" one would be rebuilt.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(self.model_name, device=self.device)
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Lazily create and return the k-means estimator."""
        if self.__clustering is None:
            self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def load(self, path):
        """Restore the estimator and cluster representatives written by ``save``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        that were produced by a trusted run of this project.
        """
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)

    def save(self, path):
        """Pickle ``(estimator, cluster representatives)`` to ``path``."""
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    def fit(self, personalities):
        """Fit k-means and choose one representative sentence per cluster.

        Args:
            personalities: Iterable of persona sentences.

        Returns:
            ``self``, so calls can be chained.
        """
        personalities = np.array(list(personalities))
        train_embeddings = self.sentence_transformer.encode(personalities)
        clusters = self.clustering.fit_predict(train_embeddings)
        persona_cluster_centers = []
        for clust, center in enumerate(self.clustering.cluster_centers_):
            members = clusters == clust
            if not members.any():
                # No sentence was assigned to this cluster; keep the slot so
                # positions in _cluster_centers still equal cluster ids.
                persona_cluster_centers.append(None)
                continue
            distances = np.linalg.norm(train_embeddings[members] - center, axis=1)
            # argmin returns the first minimum, matching a strict "<" scan.
            persona_cluster_centers.append(personalities[members][np.argmin(distances)])
        self._cluster_centers = np.array(persona_cluster_centers)
        return self

    def predict(self, personalities):
        """Return the cluster id of every sentence in ``personalities``."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        return self.clustering.predict(embeddings)

    def predict_nearest_personality(self, personalities):
        """Return the representative sentence of each input's cluster."""
        clusters = self.predict(personalities)
        return np.array([self._cluster_centers[clust] for clust in clusters])

    def fit_predict(self, personalities):
        """Fit on ``personalities`` and return their cluster ids."""
        self.fit(personalities)
        return self.predict(personalities)
|
personalized-chat-bot/personalized_chat_bot.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
from sklearn.neighbors import KDTree
|
5 |
+
|
6 |
+
|
7 |
+
class PersonalityManager:
    """Pick the trained soft prompt whose persona best matches a description.

    Only personas listed in ``prompt_paths`` (persona id -> prompt file) are
    candidates; matching is nearest-neighbour search in sentence-embedding
    space.
    """

    def __init__(self, prompt_paths, personality_clustering):
        """Embed the candidate personas and index them in a KD-tree.

        Args:
            prompt_paths: Mapping from persona (cluster) id to prompt path.
            personality_clustering: Object exposing ``_cluster_centers`` and a
                ``sentence_transformer`` with an ``encode`` method.
        """
        self.prompt_paths = prompt_paths
        self.personality_clustering = personality_clustering

        self.persona_ids = list(prompt_paths.keys())
        centers = personality_clustering._cluster_centers
        self.personalities = [centers[persona_id] for persona_id in self.persona_ids]

        encoder = personality_clustering.sentence_transformer
        self.embeddings = encoder.encode(self.personalities)
        self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')

    def get_prompt(self, description):
        """Return ``(prompt_path, persona_sentence)`` closest to ``description``."""
        encoder = self.personality_clustering.sentence_transformer
        query = encoder.encode([description])
        _, neighbour_index = self._nearest_neighbours.query(query, k=1)
        # query() returns one row per query point; take the single neighbour.
        persona_id = self.persona_ids[neighbour_index[0][0]]
        return (
            self.prompt_paths[persona_id],
            self.personality_clustering._cluster_centers[persona_id],
        )
|
26 |
+
|
27 |
+
|
28 |
+
class PersonalizedChatBot:
    """Stateful chat wrapper around a prompt-tuned causal language model.

    The whole conversation is kept as one newline-separated string in
    ``dialog`` and is re-sent to the model on every turn.
    """

    def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
        """Store the collaborators and optionally load a trained soft prompt.

        Args:
            model: Model exposing ``generate`` and
                ``transformer.prompt_embeddings``.
            tokenizer: Tokenizer matching ``model``.
            prompt_path: Optional path to a saved prompt-embedding state dict.
            generation_config: Object with ``TEMPERATURE``, ``TOP_K`` and
                ``MAX_TOKENS`` attributes; may be set later via ``load_config``.
        """
        self.model = model
        if prompt_path is not None:
            self.load_prompt(prompt_path)
        self.tokenizer = tokenizer
        self.separator = '\n'
        self.dialog = ''
        self.generation_config = generation_config

    def load_prompt(self, path):
        """Load trained prompt embeddings from ``path`` into the model."""
        # map_location lets a prompt trained on GPU be loaded on a CPU-only
        # host; load_state_dict copies the values into the existing parameters.
        state = torch.load(path, map_location='cpu')
        self.model.transformer.prompt_embeddings.load_state_dict(state)

    def load_config(self, path):
        """Read generation settings from a JSON file at ``path``."""
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        self.generation_config = argparse.Namespace(**config)

    def reset_dialog(self, ):
        """Forget the conversation so far."""
        self.dialog = ''

    def answer(self, phrase):
        """Append ``phrase`` to the dialogue and return the model's reply.

        Args:
            phrase: The user's utterance.

        Returns:
            The first line of the generated continuation, or ``None`` when
            ``phrase`` is empty (the dialogue is then left untouched).

        Raises:
            ValueError: If no generation config has been provided.
        """
        if len(phrase) == 0:
            return None
        if self.generation_config is None:
            raise ValueError('generation_config is not set; pass it to __init__ or call load_config()')
        self.dialog += f"{phrase}{self.separator}"
        inputs = self.tokenizer([self.dialog], return_tensors='pt')['input_ids']
        outputs = self.model.generate(
            inputs,
            temperature=self.generation_config.TEMPERATURE,
            do_sample=True,
            top_k=self.generation_config.TOP_K,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=self.generation_config.MAX_TOKENS,
        )
        # Decode only the newly generated ids. Decoding the full sequence and
        # slicing by len(self.dialog) breaks whenever the tokenizer round trip
        # does not reproduce the dialogue text character for character.
        new_tokens = outputs[:, inputs.shape[1]:]
        bloom_answer = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
        # Keep just the first line: anything after the separator is the model
        # continuing the conversation on the user's behalf.
        bloom_answer = bloom_answer.split(self.separator)[0]
        self.dialog += f"{bloom_answer}{self.separator}"
        return bloom_answer
|
personalized-chat-bot/prompt_paths.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"113": "./data/models/176b/113_persona_prompt_embedding.pt",
|
3 |
+
"54": "./data/models/176b/54_persona_prompt_embedding.pt",
|
4 |
+
"169": "./data/models/176b/169_persona_prompt_embedding.pt",
|
5 |
+
"364": "./data/models/176b/364_persona_prompt_embedding.pt",
|
6 |
+
"214": "./data/models/176b/214_persona_prompt_embedding.pt",
|
7 |
+
"125": "./data/models/176b/125_persona_prompt_embedding.pt",
|
8 |
+
"103": "./data/models/176b/103_persona_prompt_embedding.pt",
|
9 |
+
"200": "./data/models/176b/200_persona_prompt_embedding.pt",
|
10 |
+
"296": "./data/models/176b/296_persona_prompt_embedding.pt",
|
11 |
+
"20": "./data/models/176b/20_persona_prompt_embedding.pt",
|
12 |
+
"384": "./data/models/176b/384_persona_prompt_embedding.pt",
|
13 |
+
"365": "./data/models/176b/365_persona_prompt_embedding.pt",
|
14 |
+
"451": "./data/models/176b/451_persona_prompt_embedding.pt",
|
15 |
+
"80": "./data/models/176b/80_persona_prompt_embedding.pt"
|
16 |
+
}
|
personalized-chat-bot/scripts/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# coding=utf-8
|
personalized-chat-bot/scripts/config_176b.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
|
3 |
+
"MODEL_NAME": "bigscience/bloom-petals",
|
4 |
+
"INITIAL_PEERS": [],
|
5 |
+
"NUM_PREFIX_TOKENS": 16,
|
6 |
+
"DEVICE": "cpu",
|
7 |
+
"BATCH_SIZE": 4,
|
8 |
+
"LR": 0.01,
|
9 |
+
"WEIGHT_DECAY": 0.0,
|
10 |
+
"NUM_SAMPLES": 1000,
|
11 |
+
"SEED": 42,
|
12 |
+
"MODEL_MAX_LENGTH": 256,
|
13 |
+
"TUNING_MODE": "ptune",
|
14 |
+
"N_EPOCH": 10,
|
15 |
+
"PADDING_SIDE": "right"
|
16 |
+
}
|
personalized-chat-bot/scripts/config_6b.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
|
3 |
+
"MODEL_NAME": "bigscience/test-bloomd-6b3",
|
4 |
+
"INITIAL_PEERS":["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"],
|
5 |
+
"NUM_PREFIX_TOKENS": 16,
|
6 |
+
"DEVICE": "cpu",
|
7 |
+
"BATCH_SIZE": 4,
|
8 |
+
"LR": 0.01,
|
9 |
+
"WEIGHT_DECAY": 0.0,
|
10 |
+
"NUM_SAMPLES": 1000,
|
11 |
+
"SEED": 42,
|
12 |
+
"MODEL_MAX_LENGTH": 256,
|
13 |
+
"TUNING_MODE": "ptune",
|
14 |
+
"N_EPOCH": 1,
|
15 |
+
"PADDING_SIDE": "right"
|
16 |
+
}
|
personalized-chat-bot/scripts/fit_personality_clustering.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datasets import load_dataset
|
3 |
+
from models.personality_clustering import PersonalityClustering
|
4 |
+
import os
|
5 |
+
|
6 |
+
"""Example usage:
|
7 |
+
python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
|
8 |
+
"""
|
9 |
+
|
10 |
+
PERSONACHAT_DATASET = "bavard/personachat_truecased"
|
11 |
+
|
12 |
+
|
13 |
+
def load_persona_chat_personalities(personachat_dataset):
    """Collect the distinct persona sentences of a PersonaChat-style dataset.

    Args:
        personachat_dataset: Dataset name or path accepted by ``load_dataset``.

    Returns:
        A list of unique persona sentences drawn from the ``train`` split and,
        when present, the ``validation`` split.
    """
    dataset = load_dataset(personachat_dataset)
    personalities = set()
    # The original read 'train' twice, so held-out personas were never seen.
    # 'validation' is only read when the dataset actually provides it.
    for split in ('train', 'validation'):
        if split not in dataset:
            continue
        for persona in dataset[split]['personality']:
            personalities.update(persona)
    return list(personalities)
|
21 |
+
|
22 |
+
|
23 |
+
def parse_args(args=None):
    """Parse the command line of the clustering-fit script.

    Args:
        args: Optional list of argument strings; ``None`` reads ``sys.argv``.

    Returns:
        Namespace with ``clustering_path``, ``n_clusters`` and ``model_name``.
    """
    option_specs = (
        (('-clustering-path', '--clustering-path'),
         {'type': str, 'help': 'Path to clustering data.'}),
        (('-n-clusters', '--n-clusters'),
         {'type': int, 'default': 500, 'help': 'The number of clusters to form.'}),
        (('-model-name', '--model-name'),
         {'type': str, 'default': None, 'required': False}),
    )
    parser = argparse.ArgumentParser(description="Class for personality clustering.", add_help=True)
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args(args)
|
33 |
+
|
34 |
+
|
35 |
+
def main():
    """Fit the persona clustering on PersonaChat and pickle it to disk.

    The output file name comes from ``--model-name`` when given; otherwise it
    is derived from the cluster count and the sentence-encoder name.
    """
    args = parse_args()

    personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')

    model = PersonalityClustering(n_clusters=args.n_clusters)
    print('Model fitting')
    model.fit(personalities)
    print('Model fitted')

    model_name = args.model_name
    if model_name is None:
        model_name = f'personality_clustering_{model.n_clusters}_{model.model_name}_k-means.pkl'

    target = os.path.join(args.clustering_path, model_name)
    model.save(target)
    print(f'{model_name} saved')
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
main()
|
personalized-chat-bot/scripts/train_all.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#python -m scripts.train_bloom_personachat --persona-ids 113 54 169 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
4 |
+
#python -m scripts.train_bloom_personachat --persona-ids 364 214 125 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
5 |
+
#python -m scripts.train_bloom_personachat --persona-ids 103 200 296 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
6 |
+
#python -m scripts.train_bloom_personachat --persona-ids 20 384 365 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
7 |
+
#python -m scripts.train_bloom_personachat --persona-ids 208 43 99 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
8 |
+
#python -m scripts.train_bloom_personachat --persona-ids 426 477 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
9 |
+
python -m scripts.train_bloom_personachat --persona-ids 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
10 |
+
|
11 |
+
python -m scripts.train_bloom_personachat --persona-ids 329 402 382 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
|
personalized-chat-bot/scripts/train_bloom_personachat.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch.cuda
|
4 |
+
from datasets import load_dataset
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import transformers
|
8 |
+
from torch.utils.data import Subset
|
9 |
+
import wandb
|
10 |
+
import numpy as np
|
11 |
+
import gc
|
12 |
+
|
13 |
+
from models.personality_clustering import PersonalityClustering
|
14 |
+
from util.bloom_trainer import BloomTrainer
|
15 |
+
from util.data import PersonaChatDataset
|
16 |
+
from util.metrics import perplexity
|
17 |
+
|
18 |
+
from petals.client.remote_model import DistributedBloomForCausalLM
|
19 |
+
|
20 |
+
"""Example usage:
|
21 |
+
python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config_6b.json --prompt-path data/models/
|
22 |
+
"""
|
23 |
+
|
24 |
+
DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
|
25 |
+
MAX_VAL_DATA_SIZE = 4
|
26 |
+
|
27 |
+
|
28 |
+
def load_config(path):
    """Load a JSON training configuration into an attribute-style namespace.

    Args:
        path: Path to a JSON file whose top-level value is an object.

    Returns:
        An ``argparse.Namespace`` exposing every top-level key as an attribute.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
|
32 |
+
|
33 |
+
|
34 |
+
def main():
    """Train one soft prompt per requested persona cluster.

    For every id in ``--persona-ids`` a fresh distributed BLOOM model is
    created, prompt-tuned on the dialogues of that persona cluster, evaluated
    with perplexity, saved to ``<prompt-path>/<id>_persona_prompt_embedding.pt``
    and logged to Weights & Biases.
    """
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)

    config = load_config(args.config)

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    # Seed the validation subsampling from the config so runs are repeatable;
    # a config without SEED falls back to fresh entropy.
    rng = np.random.default_rng(getattr(config, 'SEED', None))

    # Build the model arguments once instead of duplicating the whole call:
    # the only difference between the two cases is whether peers are given.
    model_kwargs = {
        'pre_seq_len': config.NUM_PREFIX_TOKENS,
        'tuning_mode': config.TUNING_MODE,
    }
    if len(config.INITIAL_PEERS) > 0:
        model_kwargs['initial_peers'] = config.INITIAL_PEERS

    # "persona_id" rather than "id", which would shadow the builtin.
    for persona_id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path, f'{persona_id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[persona_id]
        val_dataset = personachat_val_dataset[persona_id]
        honest_validation = True
        if len(val_dataset) < 4:
            # Too few held-out dialogues for this persona: validate on the
            # training data and record that the metric is not a held-out one.
            val_dataset = personachat_train_dataset[persona_id]
            honest_validation = False
        # To speed things up, cap the validation set size.
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = rng.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes.tolist())

        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': persona_id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{persona_id}',
            reinit=True
        )

        model = DistributedBloomForCausalLM.from_pretrained(
            config.MODEL_NAME,
            **model_kwargs
        ).to(config.DEVICE)

        trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
        trainer.train()
        eval_perplexity = trainer.evaluate(perplexity)
        trainer.save_model(prompt_path)
        wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})

        # The trainer (and its optimizer) also reference the model, so both
        # must be dropped before the memory can actually be reclaimed.
        del trainer
        del model
        gc.collect()
        torch.cuda.empty_cache()
|
104 |
+
|
105 |
+
|
106 |
+
def parse_args(args=None):
    """Parse the command line of the prompt-tuning script.

    Args:
        args: Optional list of argument strings; ``None`` reads ``sys.argv``.

    Returns:
        Namespace with ``persona_ids``, ``clustering_model_path``, ``config``,
        ``prompt_path`` and ``wandb_project``.
    """
    option_specs = (
        (('--persona-ids',),
         {'type': int, 'nargs': '+', 'help': 'Ids of persona'}),
        (('-clustering-model-path', '--clustering-model-path'),
         {'type': str, 'default': DEFAULT_CLUSTERING_MODEL,
          'help': 'Path to clustering model'}),
        (('--config',),
         {'type': str, 'help': 'Path to training config file'}),
        (('--prompt-path',),
         {'type': str, 'help': 'Path to dir with trained soft prompts'}),
        (('--wandb-project',),
         {'type': str, 'default': 'test_bloom_personachat_176b_v3'}),
    )
    parser = argparse.ArgumentParser(description="bloom training script", add_help=True)
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args(args)
|
120 |
+
|
121 |
+
|
122 |
+
if __name__ == '__main__':
|
123 |
+
main()
|
personalized-chat-bot/util/__init__.py
ADDED
File without changes
|
personalized-chat-bot/util/bloom_trainer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torch.optim import AdamW
|
6 |
+
from transformers import get_scheduler
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
from util.metrics import perplexity
|
11 |
+
|
12 |
+
|
13 |
+
class BloomTrainer:
    """Prompt-tuning loop for a causal language model.

    Trains for at most ``ITERATION_LIMIT`` optimizer steps, runs validation
    every ``val_freq`` steps, and saves the prompt embeddings whenever the
    training loss seen at a validation step is the best so far.
    """

    DEFAULT_VAL_FREQ = 5
    ITERATION_LIMIT = 150

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
        """Build the data loaders, optimizer and learning-rate schedule.

        Args:
            model: Model returning ``loss`` and ``logits`` and exposing
                ``transformer.prompt_embeddings``.
            config: Object with ``BATCH_SIZE``, ``LR``, ``WEIGHT_DECAY``,
                ``N_EPOCH`` and ``DEVICE`` attributes.
            train_dataset: Dataset of items with ``input_ids`` and ``labels``.
            val_dataset: Dataset with the same item layout.
            wandb_run: Optional run object with a ``log`` method; logging is
                skipped when ``None``.
            prompt_path: Optional checkpoint path; checkpointing during
                training is skipped when ``None``.
            val_freq: Validate every this many steps (default
                ``DEFAULT_VAL_FREQ``).
        """
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = self.DEFAULT_VAL_FREQ if val_freq is None else val_freq
        self.prompt_path = prompt_path

        self.best_loss = np.inf

        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)

        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)

        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH
        )

    def _to_model_inputs(self, batch):
        """Turn a collated batch into ``(batch, seq)`` tensors on the device."""
        # The default collate function yields one tensor per token position,
        # so stacking gives (seq, batch) and the transpose gives (batch, seq).
        return {'input_ids': torch.stack(batch['input_ids']).T.to(self.config.DEVICE),
                'labels': torch.stack(batch['labels']).T.to(self.config.DEVICE)}

    def _log(self, metrics):
        """Forward ``metrics`` to the run logger when one is configured."""
        if self.wandb_run is not None:
            self.wandb_run.log(metrics)

    def train(self):
        """Run the training loop, validating and checkpointing periodically."""
        self.model.train()
        iter_counter = 0
        for _ in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                batch = self._to_model_inputs(batch)
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                loss_value = loss.item()
                self._log({'loss': loss_value})
                iter_counter += 1
                # iter_counter already counts the finished step, so validate
                # on exact multiples of val_freq.
                if iter_counter % self.val_freq == 0:
                    eval_perplexity = self.evaluate(perplexity)
                    self._log({'perplexity': eval_perplexity})
                    if loss_value < self.best_loss:
                        self.best_loss = loss_value
                        if self.prompt_path is not None:
                            self.save_model(self.prompt_path)
                            print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
        """Score the model on the validation set.

        Args:
            eval_fn: Callable ``eval_fn(logits, labels)`` receiving two lists
                with one entry per validation sequence. ``logits[k][i]`` is the
                prediction for target token ``labels[k][i]``.

        Returns:
            Whatever ``eval_fn`` returns.
        """
        logits = []
        labels = []
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in self.val_loader:
                    batch = self._to_model_inputs(batch)
                    outputs = self.model(**batch)
                    # A causal LM's logits at position i predict token i + 1,
                    # so drop the last logit and the first label to align
                    # them. Move to CPU so NumPy-based metrics can read them.
                    labels.extend(batch['labels'][:, 1:].cpu())
                    logits.extend(outputs.logits[:, :-1].detach().cpu())
        finally:
            # Restore the previous mode; otherwise every training step after
            # the first validation would run in eval mode.
            self.model.train(was_training)
        metric = eval_fn(logits, labels)
        return metric

    def save_model(self, path):
        """Write the prompt-embedding state dict to ``path``."""
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
        """Load a prompt-embedding state dict from ``path`` into the model."""
        self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))
|
personalized-chat-bot/util/data.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
from joblib import Parallel, delayed
|
5 |
+
|
6 |
+
|
7 |
+
class OnePersonaDataset(Dataset):
    """Tokenized dialogue examples that belong to a single persona cluster.

    Each example is a dialogue history followed by one candidate reply. Items
    are dicts with ``input_ids``, ``labels`` (the same ids), ``example`` (the
    text) and ``class`` (1 for the true reply, 0 for a distractor).
    """

    def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
        """Build and tokenize the examples.

        Args:
            data: Sequence of rows with ``history`` and ``candidates`` lists;
                the last candidate is treated as the true reply.
            tokenizer: Callable accepting a list of strings plus ``padding``
                and ``truncation`` keywords and returning ``{"input_ids": ...}``.
            transforms: Optional callable turning a list of utterances into one
                string; by default utterances are joined with newlines.
            positive_candidates: If true, keep only the true reply per row;
                otherwise emit one example per candidate.
            n_jobs: Number of parallel workers used when ``transforms`` is given.
        """
        super().__init__()

        self.data = data
        if len(data) == 0:
            self.input_ids = []
            self.history = []
            self.labels = []
            return

        if positive_candidates:
            self.history = [row['history'] + [row['candidates'][-1], ] for row in data]
            self.labels = np.ones(len(self.history), dtype=int)
        else:
            self.history = [row['history'] + [candidate, ] for row in data
                            for candidate in row['candidates']]
            labels = itertools.chain.from_iterable([0] * (len(row['candidates']) - 1) + [1]
                                                   for row in data)
            # np.array() cannot consume a lazy iterator with dtype=int (it
            # raises); np.fromiter is the call that materialises one.
            self.labels = np.fromiter(labels, dtype=int)

        if transforms is None:
            self.history = ["\n".join(item) for item in self.history]
        else:
            self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
        self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict (see the class docstring)."""
        return {'input_ids': self.input_ids[idx],
                'labels': self.input_ids[idx],
                'example': self.history[idx],
                'class': self.labels[idx]}

    def __len__(self):
        """Return the number of examples."""
        # Not len(self.data): with positive_candidates=False every row expands
        # into one example per candidate.
        return len(self.input_ids)
|
42 |
+
|
43 |
+
|
44 |
+
class PersonaChatDataset(Dataset):
    """PersonaChat dialogues grouped by persona cluster.

    Indexing with a cluster id returns the ``OnePersonaDataset`` holding every
    dialogue in which at least one persona sentence falls into that cluster.
    """

    DEFAULT_DATASET_NAME = "bavard/personachat_truecased"

    def __init__(self, clustering, dataset, tokenizer):
        """Assign every dialogue to the clusters of its persona sentences.

        Args:
            clustering: Fitted object with ``predict`` and ``_cluster_centers``.
            dataset: Indexable collection of rows with a ``personality`` list.
            tokenizer: Tokenizer forwarded to each ``OnePersonaDataset``.
        """
        super().__init__()

        self.dataset = dataset
        self.clustering = clustering

        all_personalities = list(set([sent for item in self.dataset
                                      for sent in item['personality']]))
        predicted_centers = self.clustering.predict(all_personalities)
        self.all_personalities_to_id = {persona: center
                                        for persona, center in zip(all_personalities, predicted_centers)}
        self.personalities = self.clustering._cluster_centers

        subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]

        for i in range(len(self.dataset)):
            item = self.dataset[i]
            cur_persona_ids = [self.all_personalities_to_id[persona] for persona in item['personality']]
            for persona_id in cur_persona_ids:
                subdataset_data_by_personality[persona_id].append(item)

        self.subdatasets = [OnePersonaDataset(cur_data, tokenizer) for cur_data in subdataset_data_by_personality]

    def __getitem__(self, persona_id):
        """Return the sub-dataset of the cluster ``persona_id``."""
        return self.subdatasets[persona_id]

    def __len__(self, ):
        """Return the number of persona clusters."""
        # The attribute is "subdatasets"; the original read the nonexistent
        # "datasets" and raised AttributeError on every len() call.
        return len(self.subdatasets)
|
personalized-chat-bot/util/dialogue_manager.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import DistilBertForSequenceClassification
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class DialogueManagerModel(nn.Module):
    """DistilBERT sequence classifier with a frozen body and a new head."""

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        """Load the default backbone and attach an ``n_classes``-way head.

        Args:
            n_classes: Number of output classes of the replacement classifier.
            model_name: Must be ``None``; other backbones are not implemented.
            device: Device for the backbone and the new classifier.

        Raises:
            NotImplementedError: If ``model_name`` is not ``None``.
        """
        super().__init__()
        if model_name is not None:
            raise NotImplementedError()
        self.model = DistilBertForSequenceClassification.from_pretrained(self.DEFAULT_MODEL)
        self.model.to(device)
        self.n_classes = n_classes

        # Freeze all pretrained weights first, then swap in a fresh head so
        # that the head is the only trainable part.
        self.freeze_layers()
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, self.n_classes, device=device)

        for weight in self.model.classifier.parameters():
            weight.requires_grad = True

    def freeze_layers(self):
        """Disable gradients for every parameter of the wrapped model."""
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Delegate to the wrapped model and return its output unchanged."""
        output = self.model(X)
        return output
|
personalized-chat-bot/util/metrics.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def _perplexity(logits, labels, pad_token=3):
|
7 |
+
for i in range(len(labels)-1, -1, -1):
|
8 |
+
if labels[i] != pad_token:
|
9 |
+
last_not_pad_id = i
|
10 |
+
break
|
11 |
+
logits = logits[:last_not_pad_id + 1]
|
12 |
+
labels = labels[:last_not_pad_id + 1]
|
13 |
+
log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
|
14 |
+
log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
|
15 |
+
l = np.mean(log_probas)
|
16 |
+
return 2 ** (-l)
|
17 |
+
|
18 |
+
|
19 |
+
def perplexity(logits, labels, pad_token=3):
    """Return the mean per-sequence perplexity of a batch.

    Args:
        logits: Batch tensor/array, or a sequence with one
            ``(seq_len, vocab_size)`` tensor/array per example.
        labels: Batch tensor/array, or a sequence with one label vector per
            example.
        pad_token: Padding token id forwarded to ``_perplexity``.

    Returns:
        The arithmetic mean of the per-sequence perplexities.
    """
    def _as_numpy(value):
        # Callers may pass a Python list of tensors, so every element has to
        # be detached and moved to the CPU, not only a top-level batch tensor.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.array(value)

    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    pp = [
        _perplexity(_as_numpy(cur_logits), _as_numpy(cur_labels).astype(int), pad_token)
        for cur_logits, cur_labels in zip(logits, labels)
    ]
    return np.mean(pp)
|