Spaces:
Runtime error
Runtime error
import argparse | |
from datasets import load_dataset | |
from models.personality_clustering import PersonalityClustering | |
import os | |
"""Example usage:
python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
"""
# Hugging Face Hub identifier of the dataset the persona sentences are read from.
PERSONACHAT_DATASET = "bavard/personachat_truecased"


def load_persona_chat_personalities(personachat_dataset):
    """Collect the distinct persona sentences from every split of a dataset.

    Args:
        personachat_dataset: Dataset identifier passed to ``load_dataset``.

    Returns:
        A sorted list of unique persona sentences gathered from the
        ``'personality'`` column of all available splits.
    """
    dataset = load_dataset(personachat_dataset)
    # Previously the 'train' split was read twice, so sentences from the
    # other split(s) never reached the clustering. Walking every split that
    # was loaded avoids hard-coding split names.
    unique_sentences = {
        sentence
        for split in dataset.values()
        # Each 'personality' entry is itself a list of sentences.
        for persona in split['personality']
        for sentence in persona
    }
    # Set iteration order varies between interpreter runs; sorting makes the
    # input to the clustering reproducible.
    return sorted(unique_sentences)
def parse_args(args=None):
    """Parse the command-line options of the clustering fitting script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``clustering_path`` (str),
        ``n_clusters`` (int, default 500) and ``model_name`` (str or None).

    Raises:
        SystemExit: If the arguments are invalid or ``--clustering-path``
            is missing (argparse prints a usage message and exits).
    """
    parser = argparse.ArgumentParser(add_help=True, description="Class for personality clustering.")
    # The output directory has no sensible default, so it is mandatory:
    # leaving it as None would only fail later when the save path is built.
    parser.add_argument('-clustering-path', '--clustering-path', type=str, required=True,
                        help='Path to clustering data.')
    parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
                        help='The number of clusters to form.')
    parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False,
                        help='File name for the saved model; generated automatically when omitted.')
    return parser.parse_args(args)
def main():
    """Fit a personality clustering model on PersonaChat and save it to disk.

    Reads the command-line options, loads the persona sentences, fits a
    ``PersonalityClustering`` model with the requested number of clusters and
    saves it under the ``--clustering-path`` directory.

    Raises:
        SystemExit: If no clustering path was supplied.
    """
    args = parse_args()

    # Fail fast: a missing path would otherwise only surface as a TypeError
    # in os.path.join after the expensive fit has already completed.
    if args.clustering_path is None:
        raise SystemExit('--clustering-path is required')
    # Make sure the target directory exists before doing any heavy work.
    # An empty string means "current directory" and needs no creation.
    if args.clustering_path:
        os.makedirs(args.clustering_path, exist_ok=True)

    personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')

    model = PersonalityClustering(n_clusters=args.n_clusters)
    print('Model fitting')
    model.fit(personalities)
    print('Model fitted')

    # Fall back to a file name derived from the model's own attributes when
    # the user did not supply one.
    model_name = args.model_name
    if model_name is None:
        model_name = f'personality_clustering_{model.n_clusters}_{model.model_name}_k-means.pkl'

    model.save(os.path.join(args.clustering_path, model_name))
    print(f'{model_name} saved')


if __name__ == '__main__':
    main()