{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from list_questions import load_questions\n", "from extract_keywords import extract_keywords, extract_keywords2\n", "db_name = 'omnidesk-ai-chatgpt-questions.sqlite'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sentence_transformers import InputExample\n", "import random\n", "\n", "def get_user_question(q):\n", " keywords = extract_keywords2(q['query'])\n", " return ' '.join([q['question'].strip(), ' '.join(keywords)]).lower()\n", " \n", "def get_system_question(q):\n", " return q['query'].lower()\n", " \n", "def get_negative_system_question(q, all_questions):\n", " negative_q = random.choice(list(filter(lambda q2: q['query'] != q2['query'], all_questions)))\n", " return negative_q['query'].lower()\n", "\n", "def input_example_generator():\n", " all_questions = list(load_questions(db_name))\n", " for q in all_questions:\n", " yield InputExample(texts=[get_user_question(q), get_system_question(q)], label=1.0)\n", " yield InputExample(texts=[get_user_question(q), get_negative_system_question(q, all_questions)], label=0.0)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import IterableDataset, DataLoader\n", "\n", "additional_examples = [\n", " InputExample(texts=['добрый день', 'добрый день, здравствуйте'], label=1.0),\n", " InputExample(texts=['здравствуйте', 'добрый день, здравствуйте'], label=1.0),\n", " InputExample(texts=['привет', 'добрый день, здравствуйте'], label=1.0),\n", " InputExample(texts=['спасибо', 'спасибо, до свидания'], label=1.0),\n", " InputExample(texts=['до свидания', 'спасибо, до свидания'], label=1.0),\n", " InputExample(texts=['не понял', 'некорректный ответ, не понял'], label=1.0),\n", " InputExample(texts=['некорректный ответ', 'некорректный ответ, не понял'], label=1.0),\n", " InputExample(texts=['как убрать ошибку', 'как убрать ошибку'], label=1.0),\n", " InputExample(texts=['как устранить ошибку', 'как убрать ошибку'], label=1.0),\n", " InputExample(texts=['как решить проблему с ошибкой', 'как убрать ошибку'], label=1.0),\n", " InputExample(texts=['есть ли способ устранить ошибку', 'как убрать ошибку'], label=1.0),\n", " InputExample(texts=['каким образом можно избавиться от ошибки', 'как убрать ошибку'], label=1.0),\n", " InputExample(texts=['позови человека', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " InputExample(texts=['позови сотрудника', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " InputExample(texts=['позови менеджера', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " InputExample(texts=['позови оператора', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " InputExample(texts=['оператор', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " InputExample(texts=['человек', 'позови человека сотрудника менеджера оператора'], label=1.0),\n", " \n", " # special cases\n", " InputExample(texts=['можете подсказать, что делать с ошибкой', 'как убрать ошибку'], label=4.0),\n", " InputExample(texts=['что произойдет при удалении оплаты cloudpayments', 'cloudpayments перенос оплаты в платежных модулях на примере модуля cloudpayments что произойдет при удалении оплаты'], label=1.0),\n", " InputExample(texts=['превышен лимит количества контактов unisender', 'экспорт сегментов в unisender ошибка превышен лимит количества контактов для текущего превышен лимит количества контактов'], label=1.0),\n", " InputExample(texts=['не отображаются тарифы', 'не передаются тарифы'], label=0.0),\n", " \n", " # ???\n", " InputExample(texts=['почему количество пользователей отличается', 'почему clientid отличается'], label=0.0),\n", " InputExample(texts=['что означает галка \\'доставка курьером\\'', 'что означает галка доставка курьером'], label=1.0),\n", " \n", " InputExample(texts=['почта россии', 'яндекс доставка'], label=0.0),\n", " InputExample(texts=['почта россии', 'яндекс метрика'], label=0.0),\n", " InputExample(texts=['яндекс доставка', 'яндекс метрика'], label=0.0),\n", " InputExample(texts=['unisender', 'яндекс доставка'], label=0.0),\n", " InputExample(texts=['альфабанк', 'яндекс доставка'], label=0.0),\n", " InputExample(texts=['почта россии', 'яндекс аудитории'], label=0.0),\n", " InputExample(texts=['sipuni', 'cloudpayments'], label=0.0),\n", " InputExample(texts=['sipuni', 'facebook'], label=0.0),\n", " InputExample(texts=['robokassa', 'вконтакте'], label=0.0),\n", " InputExample(texts=['robokassa', 'digital pipeline'], label=0.0),\n", " InputExample(texts=['facebook', 'вконтакте'], label=0.0),\n", " InputExample(texts=['facebook', 'mailchimp'], label=0.0),\n", " InputExample(texts=['почта россии', 'cloudpayments'], label=0.0),\n", "]\n", "\n", "train_dataloader = DataLoader(list(input_example_generator()) + additional_examples, batch_size=16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Pretrain" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-08-22 14:47:41.400087: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "from sentence_transformers import CrossEncoder\n", "model = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sentence_transformers import InputExample\n", "pretrain_samples = [\n", " #InputExample(texts=['тест', 'тест'], label=1.0),\n", " InputExample(texts=['пока', 'до свидания'], label=1.0),\n", " InputExample(texts=['как настроить модуль', 'как настроить модуль'], label=1.0),\n", " InputExample(texts=['как настроить модуль яндекс доставка', 'как настроить модуль почта россии'], label=0.0),\n", " InputExample(texts=['как настроить модуль почта россии', 'как настроить модуль robokassa'], label=0.0),\n", " InputExample(texts=['как настроить модуль яндекс доставка', 'как настроить модуль robokassa'], label=0.0),\n", " # InputExample(texts=['ошибка сервиса доставки почта россии', 'ошибка сервиса почта россии'], label=1.0),\n", " InputExample(texts=['ошибка дата отгрузки, полученная от яндекс.доставки', 'ошибка даты отгрузки яндекс доставки'], label=1.0),\n", " InputExample(texts=['ошибка дата отгрузки, полученная от яндекс.доставки', 'яндекс доставка ошибка сервиса доставки при выборе терминала отгрузки'], label=0.0),\n", "]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "pretrain_dataloader = DataLoader(pretrain_samples)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8492aa648e8b4165a2696c46262135aa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/4 [00:00